• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test the support for SSL and sockets
2
3import sys
4import unittest
5import unittest.mock
6from test import support
7from test.support import import_helper
8from test.support import os_helper
9from test.support import socket_helper
10from test.support import threading_helper
11from test.support import warnings_helper
12import socket
13import select
14import time
15import datetime
16import gc
17import os
18import errno
19import pprint
20import urllib.request
21import threading
22import traceback
23import weakref
24import platform
25import sysconfig
26import functools
27try:
28    import ctypes
29except ImportError:
30    ctypes = None
31
32import warnings
33with warnings.catch_warnings():
34    warnings.simplefilter('ignore', DeprecationWarning)
35    import asyncore
36
37ssl = import_helper.import_module("ssl")
38import _ssl
39
40from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType
41
# Interpreter/build facts that individual tests consult.
# sys.gettotalrefcount only exists on debug (Py_DEBUG) builds.
Py_DEBUG = getattr(sys, 'gettotalrefcount', None) is not None
Py_DEBUG_WIN32 = sys.platform == 'win32' and Py_DEBUG

# Every PROTOCOL_* method known to the ssl module, in ascending order.
PROTOCOLS = list(ssl._PROTOCOL_NAMES)
PROTOCOLS.sort()
HOST = socket_helper.HOST
IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO[:3] >= (3, 0, 0)
# Build-time default cipher configuration (None when not configured).
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')
49
# Map legacy PROTOCOL_* constants to a TLSVersion member.  Pairs whose
# protocol constant or TLSVersion member is missing from this ssl build are
# skipped.  NOTE(review): PROTOCOL_SSLv23 is paired with SSLv3, presumably as
# the lowest version that method could negotiate -- confirm with callers.
# The loop variables `proto` and `ver` remain bound at module level.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        # Not available in this build: leave it out of the mapping.
        continue
    PROTOCOL_TO_TLS_VERSION[proto] = ver
62
def data_file(*name):
    """Return the path of a test data file next to this module.

    :param name: one or more path components, joined onto the directory
        that contains this test module.
    :return: the joined path as a string.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
65
# The custom key and certificate files used in test_ssl are generated
# using Lib/test/make_ssl_certs.py.
# Other certificates are simply fetched from the internet servers they
# are meant to authenticate.

# Key/cert files, each also provided as bytes paths (os.fsencode) for tests
# that exercise bytes filename arguments.
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
# Presumably the passphrase for the *_PROTECTED files above -- verify.
KEY_PASSWORD = "somepass"
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected result of ssl._ssl._test_decode_cert(CERTFILE); compared verbatim
# in BasicSocketTests.test_parse_cert.
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests)
# Each *_HOSTNAME below is the hostname testing_context() pairs with the
# matching certificate file.
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected result of ssl._ssl._test_decode_cert(SIGNED_CERTFILE); compared
# verbatim in BasicSocketTests.test_parse_cert.
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
IDNSANSFILE = data_file("idnsans.pem")
NOSANFILE = data_file("nosan.pem")
NOSAN_HOSTNAME = 'localhost'

REMOTE_HOST = "self-signed.pythontest.net"

# Deliberately broken or absent inputs.  NONEXISTINGCERT must not exist:
# test_errors_sslwrap expects ENOENT when loading it.  The empty/bad cert and
# key file names are also used directly by the bad_cert_test-based tests.
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# Presumably Diffie-Hellman parameters (per the file name) -- confirm.
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL
# Missing constants fall back to 0, i.e. no option bits.
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
OP_IGNORE_UNEXPECTED_EOF = getattr(ssl, "OP_IGNORE_UNEXPECTED_EOF", 0)
159
# Ubuntu has patched OpenSSL and changed behavior of security level 2
# see https://bugs.python.org/issue41561#msg389003
def is_ubuntu():
    """Return True if /etc/os-release mentions "ubuntu".

    Any reference to "ubuntu" is taken to mean an Ubuntu-like distro.  The
    workaround is not required for 18.04, but doesn't hurt either.

    :return: bool; a missing, unreadable or undecodable os-release file is
        treated as "not Ubuntu" so importing this module never fails here.
    """
    try:
        with open("/etc/os-release", encoding="utf-8") as f:
            return "ubuntu" in f.read()
    except (OSError, UnicodeDecodeError):
        # FileNotFoundError on non-Linux systems; PermissionError or a
        # non-UTF-8 file must not break test collection either.
        return False


if is_ubuntu():
    def seclevel_workaround(*ctxs):
        """Lower security level to '1' and allow all ciphers for TLS 1.0/1.1"""
        for ctx in ctxs:
            # Only contexts that may still negotiate TLS 1.1 or older.
            if (
                hasattr(ctx, "minimum_version") and
                ctx.minimum_version <= ssl.TLSVersion.TLSv1_1
            ):
                ctx.set_ciphers("@SECLEVEL=1:ALL")
else:
    def seclevel_workaround(*ctxs):
        """No-op: the security-level workaround is only needed on Ubuntu."""
        pass
183
184
def has_tls_protocol(protocol):
    """Report whether a PROTOCOL_* method is available and enabled.

    :param protocol: an ``ssl._SSLMethod`` member, or its name as a string
        (which must start with ``'PROTOCOL_'``)
    :return: bool; a name unknown to the ssl module yields False
    """
    if isinstance(protocol, str):
        assert protocol.startswith('PROTOCOL_')
        resolved = getattr(ssl, protocol, None)
        if resolved is None:
            return False
        protocol = resolved

    auto_negotiating = {
        ssl.PROTOCOL_TLS,
        ssl.PROTOCOL_TLS_SERVER,
        ssl.PROTOCOL_TLS_CLIENT,
    }
    if protocol in auto_negotiating:
        # auto-negotiate protocols are always available
        return True

    # Strip the "PROTOCOL_" prefix to get the TLSVersion member name.
    version_name = protocol.name[len('PROTOCOL_'):]
    return has_tls_version(version_name)
204
205
@functools.lru_cache
def has_tls_version(version):
    """Report whether a TLS/SSL version is compiled in and enabled.

    Results are memoized per argument.

    :param version: TLS version name (e.g. ``"TLSv1_2"``) or an
        ``ssl.TLSVersion`` member
    :return: bool
    """
    # SSLv2 has no TLSVersion member at all, so answer before the lookup.
    if version == "SSLv2":
        return False

    if isinstance(version, str):
        version = ssl.TLSVersion.__members__[version]

    # Compile-time flags such as ssl.HAS_TLSv1_2.
    compiled_in = getattr(ssl, f'HAS_{version.name}')
    if not compiled_in:
        return False

    if IS_OPENSSL_3_0_0 and version < ssl.TLSVersion.TLSv1_2:
        # bpo43791: 3.0.0-alpha14 fails with TLSV1_ALERT_INTERNAL_ERROR
        return False

    # A version can be compiled in yet disabled by runtime configuration or
    # a crypto policy; a fresh client context reflects those bounds.
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    floor = getattr(ctx, 'minimum_version', None)
    if (floor is not None
            and floor != ssl.TLSVersion.MINIMUM_SUPPORTED
            and version < floor):
        return False

    ceiling = getattr(ctx, 'maximum_version', None)
    if (ceiling is not None
            and ceiling != ssl.TLSVersion.MAXIMUM_SUPPORTED
            and version > ceiling):
        return False

    return True
245
246
def requires_tls_version(version):
    """Build a decorator that skips a test when *version* is unavailable.

    :param version: TLS version name or ``ssl.TLSVersion`` member, as
        accepted by ``has_tls_version``
    :return: a decorator; the wrapped test raises ``unittest.SkipTest`` at
        call time when the version is not available.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            # Checked per call, not at decoration time.
            if has_tls_version(version):
                return func(*args, **kw)
            raise unittest.SkipTest(f"{version} is not available.")
        return wrapper
    return decorator
262
263
def handle_error(prefix):
    """Write *prefix* plus the traceback of the exception being handled.

    Output goes to sys.stdout, and only when support.verbose is set; the
    traceback is not formatted at all otherwise.
    """
    if support.verbose:
        # format_exc() joins the formatted lines with "", which is correct
        # because each line already ends in a newline.  The former
        # ' '.join() shifted every line after the first by one space.
        sys.stdout.write(prefix + traceback.format_exc())
268
269
def utc_offset():
    """Return the local UTC offset in seconds (local time = UTC + offset).

    NOTE: ignores issues like #1647654.
    """
    # Only consult localtime() when the zone defines DST at all.
    in_dst = time.daylight and time.localtime().tm_isdst > 0
    offset_west = time.altzone if in_dst else time.timezone
    return -offset_west
275
276
# Reusable decorator from warnings_helper.ignore_warnings for
# DeprecationWarning; applied as @ignore_deprecation to tests that call
# deprecated ssl APIs such as ssl.wrap_socket and ssl.match_hostname.
ignore_deprecation = warnings_helper.ignore_warnings(category=DeprecationWarning)
280
281
def test_wrap_socket(sock, *,
                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
                     ciphers=None, certfile=None, keyfile=None,
                     **kwargs):
    """Wrap *sock* using a fresh SSLContext configured from the arguments.

    Client-side wraps (``server_side`` absent or false) use a
    PROTOCOL_TLS_CLIENT context and default ``server_hostname`` to
    SIGNED_CERTFILE_HOSTNAME.  A caller-supplied ``server_hostname`` is
    kept.  Server-side wraps use a PROTOCOL_TLS_SERVER context.

    :param cert_reqs: verify_mode to set; CERT_NONE also disables
        check_hostname.  None leaves the context defaults untouched.
    :param ca_certs: passed to load_verify_locations() when not None.
    :param ciphers: passed to set_ciphers() when not None.
    :param certfile: with *keyfile*, passed to load_cert_chain() when
        either is not None.
    :param keyfile: see *certfile*.
    :param kwargs: forwarded to SSLContext.wrap_socket().
    :return: the result of SSLContext.wrap_socket().
    """
    if kwargs.get("server_side"):
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    else:
        # Previously this unconditionally overwrote server_hostname.
        kwargs.setdefault("server_hostname", SIGNED_CERTFILE_HOSTNAME)
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    if cert_reqs is not None:
        if cert_reqs == ssl.CERT_NONE:
            # check_hostname must be off before verify_mode can be CERT_NONE.
            context.check_hostname = False
        context.verify_mode = cert_reqs
    if ca_certs is not None:
        context.load_verify_locations(ca_certs)
    if certfile is not None or keyfile is not None:
        context.load_cert_chain(certfile, keyfile)
    if ciphers is not None:
        context.set_ciphers(ciphers)
    return context.wrap_socket(sock, **kwargs)
302
303
def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True):
    """Create a client/server SSLContext pair for a known test certificate.

    Usage::

        client_context, server_context, hostname = testing_context()

    :param server_cert: SIGNED_CERTFILE, SIGNED_CERTFILE2 or NOSANFILE;
        anything else raises ValueError.
    :param server_chain: when true, the server context also loads
        SIGNING_CA as a verify location.
    :return: ``(client_context, server_context, hostname)`` where hostname
        is the name associated with *server_cert*.
    """
    known_certs = (
        (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
        (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME),
        (NOSANFILE, NOSAN_HOSTNAME),
    )
    for cert_path, cert_hostname in known_certs:
        if server_cert == cert_path:
            hostname = cert_hostname
            break
    else:
        raise ValueError(server_cert)

    # The client is given SIGNING_CA as its verify location.
    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.load_verify_locations(SIGNING_CA)

    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(server_cert)
    if server_chain:
        server_context.load_verify_locations(SIGNING_CA)

    return client_context, server_context, hostname
327
328
329class BasicSocketTests(unittest.TestCase):
330
331    def test_constants(self):
332        ssl.CERT_NONE
333        ssl.CERT_OPTIONAL
334        ssl.CERT_REQUIRED
335        ssl.OP_CIPHER_SERVER_PREFERENCE
336        ssl.OP_SINGLE_DH_USE
337        ssl.OP_SINGLE_ECDH_USE
338        ssl.OP_NO_COMPRESSION
339        self.assertEqual(ssl.HAS_SNI, True)
340        self.assertEqual(ssl.HAS_ECDH, True)
341        self.assertEqual(ssl.HAS_TLSv1_2, True)
342        self.assertEqual(ssl.HAS_TLSv1_3, True)
343        ssl.OP_NO_SSLv2
344        ssl.OP_NO_SSLv3
345        ssl.OP_NO_TLSv1
346        ssl.OP_NO_TLSv1_3
347        ssl.OP_NO_TLSv1_1
348        ssl.OP_NO_TLSv1_2
349        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
350
351    def test_ssl_types(self):
352        ssl_types = [
353            _ssl._SSLContext,
354            _ssl._SSLSocket,
355            _ssl.MemoryBIO,
356            _ssl.Certificate,
357            _ssl.SSLSession,
358            _ssl.SSLError,
359        ]
360        for ssl_type in ssl_types:
361            with self.subTest(ssl_type=ssl_type):
362                with self.assertRaisesRegex(TypeError, "immutable type"):
363                    ssl_type.value = None
364        support.check_disallow_instantiation(self, _ssl.Certificate)
365
366    def test_private_init(self):
367        with self.assertRaisesRegex(TypeError, "public constructor"):
368            with socket.socket() as s:
369                ssl.SSLSocket(s)
370
371    def test_str_for_enums(self):
372        # Make sure that the PROTOCOL_* constants have enum-like string
373        # reprs.
374        proto = ssl.PROTOCOL_TLS_CLIENT
375        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS_CLIENT')
376        ctx = ssl.SSLContext(proto)
377        self.assertIs(ctx.protocol, proto)
378
379    def test_random(self):
380        v = ssl.RAND_status()
381        if support.verbose:
382            sys.stdout.write("\n RAND_status is %d (%s)\n"
383                             % (v, (v and "sufficient randomness") or
384                                "insufficient randomness"))
385
386        with warnings_helper.check_warnings():
387            data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
388        self.assertEqual(len(data), 16)
389        self.assertEqual(is_cryptographic, v == 1)
390        if v:
391            data = ssl.RAND_bytes(16)
392            self.assertEqual(len(data), 16)
393        else:
394            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
395
396        # negative num is invalid
397        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
398        with warnings_helper.check_warnings():
399            self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
400
401        ssl.RAND_add("this is a random string", 75.0)
402        ssl.RAND_add(b"this is a random bytes object", 75.0)
403        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
404
405    def test_parse_cert(self):
406        # note that this uses an 'unofficial' function in _ssl.c,
407        # provided solely for this test, to exercise the certificate
408        # parsing code
409        self.assertEqual(
410            ssl._ssl._test_decode_cert(CERTFILE),
411            CERTFILE_INFO
412        )
413        self.assertEqual(
414            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
415            SIGNED_CERTFILE_INFO
416        )
417
418        # Issue #13034: the subjectAltName in some certificates
419        # (notably projects.developer.nokia.com:443) wasn't parsed
420        p = ssl._ssl._test_decode_cert(NOKIACERT)
421        if support.verbose:
422            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
423        self.assertEqual(p['subjectAltName'],
424                         (('DNS', 'projects.developer.nokia.com'),
425                          ('DNS', 'projects.forum.nokia.com'))
426                        )
427        # extra OCSP and AIA fields
428        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
429        self.assertEqual(p['caIssuers'],
430                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
431        self.assertEqual(p['crlDistributionPoints'],
432                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
433
434    def test_parse_cert_CVE_2019_5010(self):
435        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
436        if support.verbose:
437            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
438        self.assertEqual(
439            p,
440            {
441                'issuer': (
442                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
443                'notAfter': 'Jun 14 18:00:58 2028 GMT',
444                'notBefore': 'Jun 18 18:00:58 2018 GMT',
445                'serialNumber': '02',
446                'subject': ((('countryName', 'UK'),),
447                            (('commonName',
448                              'codenomicon-vm-2.test.lal.cisco.com'),)),
449                'subjectAltName': (
450                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
451                'version': 3
452            }
453        )
454
455    def test_parse_cert_CVE_2013_4238(self):
456        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
457        if support.verbose:
458            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
459        subject = ((('countryName', 'US'),),
460                   (('stateOrProvinceName', 'Oregon'),),
461                   (('localityName', 'Beaverton'),),
462                   (('organizationName', 'Python Software Foundation'),),
463                   (('organizationalUnitName', 'Python Core Development'),),
464                   (('commonName', 'null.python.org\x00example.org'),),
465                   (('emailAddress', 'python-dev@python.org'),))
466        self.assertEqual(p['subject'], subject)
467        self.assertEqual(p['issuer'], subject)
468        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
469            san = (('DNS', 'altnull.python.org\x00example.com'),
470                   ('email', 'null@python.org\x00user@example.org'),
471                   ('URI', 'http://null.python.org\x00http://example.org'),
472                   ('IP Address', '192.0.2.1'),
473                   ('IP Address', '2001:DB8:0:0:0:0:0:1'))
474        else:
475            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
476            san = (('DNS', 'altnull.python.org\x00example.com'),
477                   ('email', 'null@python.org\x00user@example.org'),
478                   ('URI', 'http://null.python.org\x00http://example.org'),
479                   ('IP Address', '192.0.2.1'),
480                   ('IP Address', '<invalid>'))
481
482        self.assertEqual(p['subjectAltName'], san)
483
484    def test_parse_all_sans(self):
485        p = ssl._ssl._test_decode_cert(ALLSANFILE)
486        self.assertEqual(p['subjectAltName'],
487            (
488                ('DNS', 'allsans'),
489                ('othername', '<unsupported>'),
490                ('othername', '<unsupported>'),
491                ('email', 'user@example.org'),
492                ('DNS', 'www.example.org'),
493                ('DirName',
494                    ((('countryName', 'XY'),),
495                    (('localityName', 'Castle Anthrax'),),
496                    (('organizationName', 'Python Software Foundation'),),
497                    (('commonName', 'dirname example'),))),
498                ('URI', 'https://www.python.org/'),
499                ('IP Address', '127.0.0.1'),
500                ('IP Address', '0:0:0:0:0:0:0:1'),
501                ('Registered ID', '1.2.3.4.5')
502            )
503        )
504
505    def test_DER_to_PEM(self):
506        with open(CAFILE_CACERT, 'r') as f:
507            pem = f.read()
508        d1 = ssl.PEM_cert_to_DER_cert(pem)
509        p2 = ssl.DER_cert_to_PEM_cert(d1)
510        d2 = ssl.PEM_cert_to_DER_cert(p2)
511        self.assertEqual(d1, d2)
512        if not p2.startswith(ssl.PEM_HEADER + '\n'):
513            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
514        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
515            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
516
517    def test_openssl_version(self):
518        n = ssl.OPENSSL_VERSION_NUMBER
519        t = ssl.OPENSSL_VERSION_INFO
520        s = ssl.OPENSSL_VERSION
521        self.assertIsInstance(n, int)
522        self.assertIsInstance(t, tuple)
523        self.assertIsInstance(s, str)
524        # Some sanity checks follow
525        # >= 1.1.1
526        self.assertGreaterEqual(n, 0x10101000)
527        # < 4.0
528        self.assertLess(n, 0x40000000)
529        major, minor, fix, patch, status = t
530        self.assertGreaterEqual(major, 1)
531        self.assertLess(major, 4)
532        self.assertGreaterEqual(minor, 0)
533        self.assertLess(minor, 256)
534        self.assertGreaterEqual(fix, 0)
535        self.assertLess(fix, 256)
536        self.assertGreaterEqual(patch, 0)
537        self.assertLessEqual(patch, 63)
538        self.assertGreaterEqual(status, 0)
539        self.assertLessEqual(status, 15)
540
541        libressl_ver = f"LibreSSL {major:d}"
542        if major >= 3:
543            # 3.x uses 0xMNN00PP0L
544            openssl_ver = f"OpenSSL {major:d}.{minor:d}.{patch:d}"
545        else:
546            openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
547        self.assertTrue(
548            s.startswith((openssl_ver, libressl_ver)),
549            (s, t, hex(n))
550        )
551
552    @support.cpython_only
553    def test_refcycle(self):
554        # Issue #7943: an SSL object doesn't create reference cycles with
555        # itself.
556        s = socket.socket(socket.AF_INET)
557        ss = test_wrap_socket(s)
558        wr = weakref.ref(ss)
559        with warnings_helper.check_warnings(("", ResourceWarning)):
560            del ss
561        self.assertEqual(wr(), None)
562
563    def test_wrapped_unconnected(self):
564        # Methods on an unconnected SSLSocket propagate the original
565        # OSError raise by the underlying socket object.
566        s = socket.socket(socket.AF_INET)
567        with test_wrap_socket(s) as ss:
568            self.assertRaises(OSError, ss.recv, 1)
569            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
570            self.assertRaises(OSError, ss.recvfrom, 1)
571            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
572            self.assertRaises(OSError, ss.send, b'x')
573            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
574            self.assertRaises(NotImplementedError, ss.dup)
575            self.assertRaises(NotImplementedError, ss.sendmsg,
576                              [b'x'], (), 0, ('0.0.0.0', 0))
577            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
578            self.assertRaises(NotImplementedError, ss.recvmsg_into,
579                              [bytearray(100)])
580
581    def test_timeout(self):
582        # Issue #8524: when creating an SSL socket, the timeout of the
583        # original socket should be retained.
584        for timeout in (None, 0.0, 5.0):
585            s = socket.socket(socket.AF_INET)
586            s.settimeout(timeout)
587            with test_wrap_socket(s) as ss:
588                self.assertEqual(timeout, ss.gettimeout())
589
590    def test_openssl111_deprecations(self):
591        options = [
592            ssl.OP_NO_TLSv1,
593            ssl.OP_NO_TLSv1_1,
594            ssl.OP_NO_TLSv1_2,
595            ssl.OP_NO_TLSv1_3
596        ]
597        protocols = [
598            ssl.PROTOCOL_TLSv1,
599            ssl.PROTOCOL_TLSv1_1,
600            ssl.PROTOCOL_TLSv1_2,
601            ssl.PROTOCOL_TLS
602        ]
603        versions = [
604            ssl.TLSVersion.SSLv3,
605            ssl.TLSVersion.TLSv1,
606            ssl.TLSVersion.TLSv1_1,
607        ]
608
609        for option in options:
610            with self.subTest(option=option):
611                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
612                with self.assertWarns(DeprecationWarning) as cm:
613                    ctx.options |= option
614                self.assertEqual(
615                    'ssl.OP_NO_SSL*/ssl.OP_NO_TLS* options are deprecated',
616                    str(cm.warning)
617                )
618
619        for protocol in protocols:
620            with self.subTest(protocol=protocol):
621                with self.assertWarns(DeprecationWarning) as cm:
622                    ssl.SSLContext(protocol)
623                self.assertEqual(
624                    f'ssl.{protocol.name} is deprecated',
625                    str(cm.warning)
626                )
627
628        for version in versions:
629            with self.subTest(version=version):
630                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
631                with self.assertWarns(DeprecationWarning) as cm:
632                    ctx.minimum_version = version
633                self.assertEqual(
634                    f'ssl.{version!s} is deprecated',
635                    str(cm.warning)
636                )
637
638    @ignore_deprecation
639    def test_errors_sslwrap(self):
640        sock = socket.socket()
641        self.assertRaisesRegex(ValueError,
642                        "certfile must be specified",
643                        ssl.wrap_socket, sock, keyfile=CERTFILE)
644        self.assertRaisesRegex(ValueError,
645                        "certfile must be specified for server-side operations",
646                        ssl.wrap_socket, sock, server_side=True)
647        self.assertRaisesRegex(ValueError,
648                        "certfile must be specified for server-side operations",
649                         ssl.wrap_socket, sock, server_side=True, certfile="")
650        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
651            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
652                                     s.connect, (HOST, 8080))
653        with self.assertRaises(OSError) as cm:
654            with socket.socket() as sock:
655                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
656        self.assertEqual(cm.exception.errno, errno.ENOENT)
657        with self.assertRaises(OSError) as cm:
658            with socket.socket() as sock:
659                ssl.wrap_socket(sock,
660                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
661        self.assertEqual(cm.exception.errno, errno.ENOENT)
662        with self.assertRaises(OSError) as cm:
663            with socket.socket() as sock:
664                ssl.wrap_socket(sock,
665                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
666        self.assertEqual(cm.exception.errno, errno.ENOENT)
667
668    def bad_cert_test(self, certfile):
669        """Check that trying to use the given client certificate fails"""
670        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
671                                   certfile)
672        sock = socket.socket()
673        self.addCleanup(sock.close)
674        with self.assertRaises(ssl.SSLError):
675            test_wrap_socket(sock,
676                             certfile=certfile)
677
678    def test_empty_cert(self):
679        """Wrapping with an empty cert file"""
680        self.bad_cert_test("nullcert.pem")
681
682    def test_malformed_cert(self):
683        """Wrapping with a badly formatted certificate (syntax error)"""
684        self.bad_cert_test("badcert.pem")
685
686    def test_malformed_key(self):
687        """Wrapping with a badly formatted key (syntax error)"""
688        self.bad_cert_test("badkey.pem")
689
690    @ignore_deprecation
691    def test_match_hostname(self):
692        def ok(cert, hostname):
693            ssl.match_hostname(cert, hostname)
694        def fail(cert, hostname):
695            self.assertRaises(ssl.CertificateError,
696                              ssl.match_hostname, cert, hostname)
697
698        # -- Hostname matching --
699
700        cert = {'subject': ((('commonName', 'example.com'),),)}
701        ok(cert, 'example.com')
702        ok(cert, 'ExAmple.cOm')
703        fail(cert, 'www.example.com')
704        fail(cert, '.example.com')
705        fail(cert, 'example.org')
706        fail(cert, 'exampleXcom')
707
708        cert = {'subject': ((('commonName', '*.a.com'),),)}
709        ok(cert, 'foo.a.com')
710        fail(cert, 'bar.foo.a.com')
711        fail(cert, 'a.com')
712        fail(cert, 'Xa.com')
713        fail(cert, '.a.com')
714
715        # only match wildcards when they are the only thing
716        # in left-most segment
717        cert = {'subject': ((('commonName', 'f*.com'),),)}
718        fail(cert, 'foo.com')
719        fail(cert, 'f.com')
720        fail(cert, 'bar.com')
721        fail(cert, 'foo.a.com')
722        fail(cert, 'bar.foo.com')
723
724        # NULL bytes are bad, CVE-2013-4073
725        cert = {'subject': ((('commonName',
726                              'null.python.org\x00example.org'),),)}
727        ok(cert, 'null.python.org\x00example.org') # or raise an error?
728        fail(cert, 'example.org')
729        fail(cert, 'null.python.org')
730
731        # error cases with wildcards
732        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
733        fail(cert, 'bar.foo.a.com')
734        fail(cert, 'a.com')
735        fail(cert, 'Xa.com')
736        fail(cert, '.a.com')
737
738        cert = {'subject': ((('commonName', 'a.*.com'),),)}
739        fail(cert, 'a.foo.com')
740        fail(cert, 'a..com')
741        fail(cert, 'a.com')
742
743        # wildcard doesn't match IDNA prefix 'xn--'
744        idna = 'püthon.python.org'.encode("idna").decode("ascii")
745        cert = {'subject': ((('commonName', idna),),)}
746        ok(cert, idna)
747        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
748        fail(cert, idna)
749        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
750        fail(cert, idna)
751
752        # wildcard in first fragment and  IDNA A-labels in sequent fragments
753        # are supported.
754        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
755        cert = {'subject': ((('commonName', idna),),)}
756        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
757        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
758        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
759        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
760
761        # Slightly fake real-world example
762        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
763                'subject': ((('commonName', 'linuxfrz.org'),),),
764                'subjectAltName': (('DNS', 'linuxfr.org'),
765                                   ('DNS', 'linuxfr.com'),
766                                   ('othername', '<unsupported>'))}
767        ok(cert, 'linuxfr.org')
768        ok(cert, 'linuxfr.com')
769        # Not a "DNS" entry
770        fail(cert, '<unsupported>')
771        # When there is a subjectAltName, commonName isn't used
772        fail(cert, 'linuxfrz.org')
773
774        # A pristine real-world example
775        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
776                'subject': ((('countryName', 'US'),),
777                            (('stateOrProvinceName', 'California'),),
778                            (('localityName', 'Mountain View'),),
779                            (('organizationName', 'Google Inc'),),
780                            (('commonName', 'mail.google.com'),))}
781        ok(cert, 'mail.google.com')
782        fail(cert, 'gmail.com')
783        # Only commonName is considered
784        fail(cert, 'California')
785
786        # -- IPv4 matching --
787        cert = {'subject': ((('commonName', 'example.com'),),),
788                'subjectAltName': (('DNS', 'example.com'),
789                                   ('IP Address', '10.11.12.13'),
790                                   ('IP Address', '14.15.16.17'),
791                                   ('IP Address', '127.0.0.1'))}
792        ok(cert, '10.11.12.13')
793        ok(cert, '14.15.16.17')
794        # socket.inet_ntoa(socket.inet_aton('127.1')) == '127.0.0.1'
795        fail(cert, '127.1')
796        fail(cert, '14.15.16.17 ')
797        fail(cert, '14.15.16.17 extra data')
798        fail(cert, '14.15.16.18')
799        fail(cert, 'example.net')
800
801        # -- IPv6 matching --
802        if socket_helper.IPV6_ENABLED:
803            cert = {'subject': ((('commonName', 'example.com'),),),
804                    'subjectAltName': (
805                        ('DNS', 'example.com'),
806                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
807                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
808            ok(cert, '2001::cafe')
809            ok(cert, '2003::baba')
810            fail(cert, '2003::baba ')
811            fail(cert, '2003::baba extra data')
812            fail(cert, '2003::bebe')
813            fail(cert, 'example.net')
814
815        # -- Miscellaneous --
816
817        # Neither commonName nor subjectAltName
818        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
819                'subject': ((('countryName', 'US'),),
820                            (('stateOrProvinceName', 'California'),),
821                            (('localityName', 'Mountain View'),),
822                            (('organizationName', 'Google Inc'),))}
823        fail(cert, 'mail.google.com')
824
825        # No DNS entry in subjectAltName but a commonName
826        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
827                'subject': ((('countryName', 'US'),),
828                            (('stateOrProvinceName', 'California'),),
829                            (('localityName', 'Mountain View'),),
830                            (('commonName', 'mail.google.com'),)),
831                'subjectAltName': (('othername', 'blabla'), )}
832        ok(cert, 'mail.google.com')
833
834        # No DNS entry subjectAltName and no commonName
835        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
836                'subject': ((('countryName', 'US'),),
837                            (('stateOrProvinceName', 'California'),),
838                            (('localityName', 'Mountain View'),),
839                            (('organizationName', 'Google Inc'),)),
840                'subjectAltName': (('othername', 'blabla'),)}
841        fail(cert, 'google.com')
842
843        # Empty cert / no cert
844        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
845        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
846
847        # Issue #17980: avoid denials of service by refusing more than one
848        # wildcard per fragment.
849        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
850        with self.assertRaisesRegex(
851                ssl.CertificateError,
852                "partial wildcards in leftmost label are not supported"):
853            ssl.match_hostname(cert, 'axxb.example.com')
854
855        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
856        with self.assertRaisesRegex(
857                ssl.CertificateError,
858                "wildcard can only be present in the leftmost label"):
859            ssl.match_hostname(cert, 'www.sub.example.com')
860
861        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
862        with self.assertRaisesRegex(
863                ssl.CertificateError,
864                "too many wildcards"):
865            ssl.match_hostname(cert, 'axxbxxc.example.com')
866
867        cert = {'subject': ((('commonName', '*'),),)}
868        with self.assertRaisesRegex(
869                ssl.CertificateError,
870                "sole wildcard without additional labels are not support"):
871            ssl.match_hostname(cert, 'host')
872
873        cert = {'subject': ((('commonName', '*.com'),),)}
874        with self.assertRaisesRegex(
875                ssl.CertificateError,
876                r"hostname 'com' doesn't match '\*.com'"):
877            ssl.match_hostname(cert, 'com')
878
879        # extra checks for _inet_paton()
880        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
881            with self.assertRaises(ValueError):
882                ssl._inet_paton(invalid)
883        for ipaddr in ['127.0.0.1', '192.168.0.1']:
884            self.assertTrue(ssl._inet_paton(ipaddr))
885        if socket_helper.IPV6_ENABLED:
886            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
887                self.assertTrue(ssl._inet_paton(ipaddr))
888
889    def test_server_side(self):
890        # server_hostname doesn't work for server sockets
891        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
892        with socket.socket() as sock:
893            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
894                              server_hostname="some.hostname")
895
896    def test_unknown_channel_binding(self):
897        # should raise ValueError for unknown type
898        s = socket.create_server(('127.0.0.1', 0))
899        c = socket.socket(socket.AF_INET)
900        c.connect(s.getsockname())
901        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
902            with self.assertRaises(ValueError):
903                ss.get_channel_binding("unknown-type")
904        s.close()
905
906    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
907                         "'tls-unique' channel binding not available")
908    def test_tls_unique_channel_binding(self):
909        # unconnected should return None for known type
910        s = socket.socket(socket.AF_INET)
911        with test_wrap_socket(s) as ss:
912            self.assertIsNone(ss.get_channel_binding("tls-unique"))
913        # the same for server-side
914        s = socket.socket(socket.AF_INET)
915        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
916            self.assertIsNone(ss.get_channel_binding("tls-unique"))
917
918    def test_dealloc_warn(self):
919        ss = test_wrap_socket(socket.socket(socket.AF_INET))
920        r = repr(ss)
921        with self.assertWarns(ResourceWarning) as cm:
922            ss = None
923            support.gc_collect()
924        self.assertIn(r, str(cm.warning.args[0]))
925
926    def test_get_default_verify_paths(self):
927        paths = ssl.get_default_verify_paths()
928        self.assertEqual(len(paths), 6)
929        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
930
931        with os_helper.EnvironmentVarGuard() as env:
932            env["SSL_CERT_DIR"] = CAPATH
933            env["SSL_CERT_FILE"] = CERTFILE
934            paths = ssl.get_default_verify_paths()
935            self.assertEqual(paths.cafile, CERTFILE)
936            self.assertEqual(paths.capath, CAPATH)
937
938    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
939    def test_enum_certificates(self):
940        self.assertTrue(ssl.enum_certificates("CA"))
941        self.assertTrue(ssl.enum_certificates("ROOT"))
942
943        self.assertRaises(TypeError, ssl.enum_certificates)
944        self.assertRaises(WindowsError, ssl.enum_certificates, "")
945
946        trust_oids = set()
947        for storename in ("CA", "ROOT"):
948            store = ssl.enum_certificates(storename)
949            self.assertIsInstance(store, list)
950            for element in store:
951                self.assertIsInstance(element, tuple)
952                self.assertEqual(len(element), 3)
953                cert, enc, trust = element
954                self.assertIsInstance(cert, bytes)
955                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
956                self.assertIsInstance(trust, (frozenset, set, bool))
957                if isinstance(trust, (frozenset, set)):
958                    trust_oids.update(trust)
959
960        serverAuth = "1.3.6.1.5.5.7.3.1"
961        self.assertIn(serverAuth, trust_oids)
962
963    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
964    def test_enum_crls(self):
965        self.assertTrue(ssl.enum_crls("CA"))
966        self.assertRaises(TypeError, ssl.enum_crls)
967        self.assertRaises(WindowsError, ssl.enum_crls, "")
968
969        crls = ssl.enum_crls("CA")
970        self.assertIsInstance(crls, list)
971        for element in crls:
972            self.assertIsInstance(element, tuple)
973            self.assertEqual(len(element), 2)
974            self.assertIsInstance(element[0], bytes)
975            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
976
977
978    def test_asn1object(self):
979        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
980                    '1.3.6.1.5.5.7.3.1')
981
982        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
983        self.assertEqual(val, expected)
984        self.assertEqual(val.nid, 129)
985        self.assertEqual(val.shortname, 'serverAuth')
986        self.assertEqual(val.longname, 'TLS Web Server Authentication')
987        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
988        self.assertIsInstance(val, ssl._ASN1Object)
989        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
990
991        val = ssl._ASN1Object.fromnid(129)
992        self.assertEqual(val, expected)
993        self.assertIsInstance(val, ssl._ASN1Object)
994        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
995        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
996            ssl._ASN1Object.fromnid(100000)
997        for i in range(1000):
998            try:
999                obj = ssl._ASN1Object.fromnid(i)
1000            except ValueError:
1001                pass
1002            else:
1003                self.assertIsInstance(obj.nid, int)
1004                self.assertIsInstance(obj.shortname, str)
1005                self.assertIsInstance(obj.longname, str)
1006                self.assertIsInstance(obj.oid, (str, type(None)))
1007
1008        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
1009        self.assertEqual(val, expected)
1010        self.assertIsInstance(val, ssl._ASN1Object)
1011        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
1012        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
1013                         expected)
1014        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
1015            ssl._ASN1Object.fromname('serverauth')
1016
1017    def test_purpose_enum(self):
1018        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
1019        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
1020        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
1021        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
1022        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
1023        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
1024                              '1.3.6.1.5.5.7.3.1')
1025
1026        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
1027        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
1028        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
1029        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
1030        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
1031        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
1032                              '1.3.6.1.5.5.7.3.2')
1033
1034    def test_unsupported_dtls(self):
1035        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1036        self.addCleanup(s.close)
1037        with self.assertRaises(NotImplementedError) as cx:
1038            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
1039        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1040        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1041        with self.assertRaises(NotImplementedError) as cx:
1042            ctx.wrap_socket(s)
1043        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1044
1045    def cert_time_ok(self, timestring, timestamp):
1046        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
1047
1048    def cert_time_fail(self, timestring):
1049        with self.assertRaises(ValueError):
1050            ssl.cert_time_to_seconds(timestring)
1051
1052    @unittest.skipUnless(utc_offset(),
1053                         'local time needs to be different from UTC')
1054    def test_cert_time_to_seconds_timezone(self):
1055        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
1056        #               results if local timezone is not UTC
1057        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
1058        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
1059
1060    def test_cert_time_to_seconds(self):
1061        timestring = "Jan  5 09:34:43 2018 GMT"
1062        ts = 1515144883.0
1063        self.cert_time_ok(timestring, ts)
1064        # accept keyword parameter, assert its name
1065        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
1066        # accept both %e and %d (space or zero generated by strftime)
1067        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
1068        # case-insensitive
1069        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
1070        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
1071        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
1072        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
1073        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
1074        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
1075        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
1076        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
1077
1078        newyear_ts = 1230768000.0
1079        # leap seconds
1080        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
1081        # same timestamp
1082        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
1083
1084        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
1085        #  allow 60th second (even if it is not a leap second)
1086        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
1087        #  allow 2nd leap second for compatibility with time.strptime()
1088        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
1089        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
1090
1091        # no special treatment for the special value:
1092        #   99991231235959Z (rfc 5280)
1093        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
1094
1095    @support.run_with_locale('LC_ALL', '')
1096    def test_cert_time_to_seconds_locale(self):
1097        # `cert_time_to_seconds()` should be locale independent
1098
1099        def local_february_name():
1100            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
1101
1102        if local_february_name().lower() == 'feb':
1103            self.skipTest("locale-specific month name needs to be "
1104                          "different from C locale")
1105
1106        # locale-independent
1107        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
1108        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
1109
1110    def test_connect_ex_error(self):
1111        server = socket.socket(socket.AF_INET)
1112        self.addCleanup(server.close)
1113        port = socket_helper.bind_port(server)  # Reserve port but don't listen
1114        s = test_wrap_socket(socket.socket(socket.AF_INET),
1115                            cert_reqs=ssl.CERT_REQUIRED)
1116        self.addCleanup(s.close)
1117        rc = s.connect_ex((HOST, port))
1118        # Issue #19919: Windows machines or VMs hosted on Windows
1119        # machines sometimes return EWOULDBLOCK.
1120        errors = (
1121            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1122            errno.EWOULDBLOCK,
1123        )
1124        self.assertIn(rc, errors)
1125
1126    def test_read_write_zero(self):
1127        # empty reads and writes now work, bpo-42854, bpo-31711
1128        client_context, server_context, hostname = testing_context()
1129        server = ThreadedEchoServer(context=server_context)
1130        with server:
1131            with client_context.wrap_socket(socket.socket(),
1132                                            server_hostname=hostname) as s:
1133                s.connect((HOST, server.port))
1134                self.assertEqual(s.recv(0), b"")
1135                self.assertEqual(s.send(b""), 0)
1136
1137
1138class ContextTests(unittest.TestCase):
1139
1140    def test_constructor(self):
1141        for protocol in PROTOCOLS:
1142            with warnings_helper.check_warnings():
1143                ctx = ssl.SSLContext(protocol)
1144            self.assertEqual(ctx.protocol, protocol)
1145        with warnings_helper.check_warnings():
1146            ctx = ssl.SSLContext()
1147        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1148        self.assertRaises(ValueError, ssl.SSLContext, -1)
1149        self.assertRaises(ValueError, ssl.SSLContext, 42)
1150
1151    def test_ciphers(self):
1152        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1153        ctx.set_ciphers("ALL")
1154        ctx.set_ciphers("DEFAULT")
1155        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1156            ctx.set_ciphers("^$:,;?*'dorothyx")
1157
1158    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1159                         "Test applies only to Python default ciphers")
1160    def test_python_ciphers(self):
1161        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1162        ciphers = ctx.get_ciphers()
1163        for suite in ciphers:
1164            name = suite['name']
1165            self.assertNotIn("PSK", name)
1166            self.assertNotIn("SRP", name)
1167            self.assertNotIn("MD5", name)
1168            self.assertNotIn("RC4", name)
1169            self.assertNotIn("3DES", name)
1170
1171    def test_get_ciphers(self):
1172        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1173        ctx.set_ciphers('AESGCM')
1174        names = set(d['name'] for d in ctx.get_ciphers())
1175        self.assertIn('AES256-GCM-SHA384', names)
1176        self.assertIn('AES128-GCM-SHA256', names)
1177
1178    def test_options(self):
1179        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1180        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
1181        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
1182        # SSLContext also enables these by default
1183        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
1184                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
1185                    OP_ENABLE_MIDDLEBOX_COMPAT |
1186                    OP_IGNORE_UNEXPECTED_EOF)
1187        self.assertEqual(default, ctx.options)
1188        with warnings_helper.check_warnings():
1189            ctx.options |= ssl.OP_NO_TLSv1
1190        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
1191        with warnings_helper.check_warnings():
1192            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
1193        self.assertEqual(default, ctx.options)
1194        ctx.options = 0
1195        # Ubuntu has OP_NO_SSLv3 forced on by default
1196        self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
1197
1198    def test_verify_mode_protocol(self):
1199        with warnings_helper.check_warnings():
1200            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1201        # Default value
1202        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1203        ctx.verify_mode = ssl.CERT_OPTIONAL
1204        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1205        ctx.verify_mode = ssl.CERT_REQUIRED
1206        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1207        ctx.verify_mode = ssl.CERT_NONE
1208        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1209        with self.assertRaises(TypeError):
1210            ctx.verify_mode = None
1211        with self.assertRaises(ValueError):
1212            ctx.verify_mode = 42
1213
1214        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1215        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1216        self.assertFalse(ctx.check_hostname)
1217
1218        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1219        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1220        self.assertTrue(ctx.check_hostname)
1221
1222    def test_hostname_checks_common_name(self):
1223        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1224        self.assertTrue(ctx.hostname_checks_common_name)
1225        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1226            ctx.hostname_checks_common_name = True
1227            self.assertTrue(ctx.hostname_checks_common_name)
1228            ctx.hostname_checks_common_name = False
1229            self.assertFalse(ctx.hostname_checks_common_name)
1230            ctx.hostname_checks_common_name = True
1231            self.assertTrue(ctx.hostname_checks_common_name)
1232        else:
1233            with self.assertRaises(AttributeError):
1234                ctx.hostname_checks_common_name = True
1235
    @ignore_deprecation
    def test_min_max_version(self):
        # Exercise SSLContext.minimum_version / maximum_version: vendor-dependent
        # defaults, explicit assignment, the MINIMUM_SUPPORTED /
        # MAXIMUM_SUPPORTED sentinels, and rejection on a fixed-version context.
        # NOTE(review): @ignore_deprecation presumably silences warnings from
        # the deprecated PROTOCOL_TLSv1_1 / TLSv1 names used below — confirm.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors override
        # it (Fedora 29 uses TLS 1.0, RHEL 8 uses TLS 1.2).
        minimum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MINIMUM_SUPPORTED,
            # Fedora 29 uses TLS 1.0 by default
            ssl.TLSVersion.TLSv1,
            # RHEL 8 uses TLS 1.2 by default
            ssl.TLSVersion.TLSv1_2
        }
        maximum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MAXIMUM_SUPPORTED,
            # Fedora 32 uses TLS 1.3 by default
            ssl.TLSVersion.TLSv1_3
        }

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertIn(
            ctx.maximum_version, maximum_range
        )

        # Explicit concrete versions read back unchanged.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
        )

        # The sentinel is accepted as a minimum and reads back as itself.
        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        ctx.maximum_version = ssl.TLSVersion.TLSv1
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1
        )

        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )

        # Using a sentinel on the "opposite" side reads back as a concrete
        # version rather than the sentinel itself.
        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        self.assertIn(
            ctx.maximum_version,
            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3}
        )

        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertIn(
            ctx.minimum_version,
            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
        )

        # Values that are not TLSVersion members are rejected.
        with self.assertRaises(ValueError):
            ctx.minimum_version = 42

        # On a context created for one specific protocol version, the test
        # expects the bounds to be read-only: assignments raise ValueError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )
        with self.assertRaises(ValueError):
            ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        with self.assertRaises(ValueError):
            ctx.maximum_version = ssl.TLSVersion.TLSv1
1313
1314
1315    @unittest.skipUnless(
1316        hasattr(ssl.SSLContext, 'security_level'),
1317        "requires OpenSSL >= 1.1.0"
1318    )
1319    def test_security_level(self):
1320        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1321        # The default security callback allows for levels between 0-5
1322        # with OpenSSL defaulting to 1, however some vendors override the
1323        # default value (e.g. Debian defaults to 2)
1324        security_level_range = {
1325            0,
1326            1, # OpenSSL default
1327            2, # Debian
1328            3,
1329            4,
1330            5,
1331        }
1332        self.assertIn(ctx.security_level, security_level_range)
1333
1334    def test_verify_flags(self):
1335        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1336        # default value
1337        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1338        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1339        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1340        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1341        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1342        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1343        ctx.verify_flags = ssl.VERIFY_DEFAULT
1344        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1345        ctx.verify_flags = ssl.VERIFY_ALLOW_PROXY_CERTS
1346        self.assertEqual(ctx.verify_flags, ssl.VERIFY_ALLOW_PROXY_CERTS)
1347        # supports any value
1348        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1349        self.assertEqual(ctx.verify_flags,
1350                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1351        with self.assertRaises(TypeError):
1352            ctx.verify_flags = None
1353
    def test_load_cert_chain(self):
        # Exercise SSLContext.load_cert_chain(): combined vs. separate key and
        # cert files, error reporting, and every accepted form of `password`.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        # certfile is mandatory: supplying only keyfile is a TypeError.
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        # A missing file surfaces as OSError with errno ENOENT.
        with self.assertRaises(OSError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        # BADCERT / EMPTYCERT (presumably malformed and empty PEM files) are
        # rejected with an SSLError mentioning "PEM lib".
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        # BYTES_* are presumably bytes-typed paths to the same files.
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A cert without its key, a key alone, or the two swapped all fail.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert
        # The password may be passed as str, bytes or bytearray, both for a
        # combined file and for a separate protected key file.
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        # A non-string, non-callable password is a TypeError; a wrong
        # password is an SSLError.
        with self.assertRaisesRegex(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Pass a password far larger than this (100 KiB).
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback
        # Each helper below returns one password type or triggers one failure
        # mode when used as the `password` callable.
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        # Plain functions, callable instances and bound methods all work.
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        # Callback failures mirror the direct-password errors above, and an
        # exception raised inside the callback surfaces with its message.
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegex(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        with self.assertRaisesRegex(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1436
1437    def test_load_verify_locations(self):
1438        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1439        ctx.load_verify_locations(CERTFILE)
1440        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1441        ctx.load_verify_locations(BYTES_CERTFILE)
1442        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1443        self.assertRaises(TypeError, ctx.load_verify_locations)
1444        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1445        with self.assertRaises(OSError) as cm:
1446            ctx.load_verify_locations(NONEXISTINGCERT)
1447        self.assertEqual(cm.exception.errno, errno.ENOENT)
1448        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1449            ctx.load_verify_locations(BADCERT)
1450        ctx.load_verify_locations(CERTFILE, CAPATH)
1451        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1452
1453        # Issue #10989: crash if the second argument type is invalid
1454        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1455
    def test_load_verify_cadata(self):
        # load_verify_locations(cadata=...) loads CA certificates held in
        # memory, either as PEM text (str) or as DER bytes.
        # cert_store_stats()["x509_ca"] is used to count the CA certs in the
        # store; loading a certificate that is already present leaves the
        # count unchanged.
        # test cadata
        with open(CAFILE_CACERT) as f:
            cacert_pem = f.read()
        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
        with open(CAFILE_NEURONIO) as f:
            neuronio_pem = f.read()
        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)

        # test PEM
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
        ctx.load_verify_locations(cadata=cacert_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        # (several PEM blocks in a single cadata string are all loaded)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = "\n".join((cacert_pem, neuronio_pem))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # with junk around the certs
        # (text outside the PEM blocks is tolerated, and the duplicated
        # neuronio cert still only counts once)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
                    neuronio_pem, "tail"]
        ctx.load_verify_locations(cadata="\n".join(combined))
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # test DER
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_verify_locations(cadata=cacert_der)
        ctx.load_verify_locations(cadata=neuronio_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=cacert_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        # (concatenated DER certificates in one bytes object)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = b"".join((cacert_der, neuronio_der))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # error cases
        # A non-str/bytes cadata is a TypeError; data containing no
        # certificate is an SSLError whose message differs for str (PEM)
        # and bytes (DER) input.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)

        with self.assertRaisesRegex(
            ssl.SSLError,
            "no start line: cadata does not contain a certificate"
        ):
            ctx.load_verify_locations(cadata="broken")
        with self.assertRaisesRegex(
            ssl.SSLError,
            "not enough data: cadata does not contain a certificate"
        ):
            ctx.load_verify_locations(cadata=b"broken")
1518
1519    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
1520    def test_load_dh_params(self):
1521        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1522        ctx.load_dh_params(DHFILE)
1523        if os.name != 'nt':
1524            ctx.load_dh_params(BYTES_DHFILE)
1525        self.assertRaises(TypeError, ctx.load_dh_params)
1526        self.assertRaises(TypeError, ctx.load_dh_params, None)
1527        with self.assertRaises(FileNotFoundError) as cm:
1528            ctx.load_dh_params(NONEXISTINGCERT)
1529        self.assertEqual(cm.exception.errno, errno.ENOENT)
1530        with self.assertRaises(ssl.SSLError) as cm:
1531            ctx.load_dh_params(CERTFILE)
1532
1533    def test_session_stats(self):
1534        for proto in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
1535            ctx = ssl.SSLContext(proto)
1536            self.assertEqual(ctx.session_stats(), {
1537                'number': 0,
1538                'connect': 0,
1539                'connect_good': 0,
1540                'connect_renegotiate': 0,
1541                'accept': 0,
1542                'accept_good': 0,
1543                'accept_renegotiate': 0,
1544                'hits': 0,
1545                'misses': 0,
1546                'timeouts': 0,
1547                'cache_full': 0,
1548            })
1549
1550    def test_set_default_verify_paths(self):
1551        # There's not much we can do to test that it acts as expected,
1552        # so just check it doesn't crash or raise an exception.
1553        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1554        ctx.set_default_verify_paths()
1555
1556    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1557    def test_set_ecdh_curve(self):
1558        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1559        ctx.set_ecdh_curve("prime256v1")
1560        ctx.set_ecdh_curve(b"prime256v1")
1561        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1562        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1563        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1564        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1565
1566    def test_sni_callback(self):
1567        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1568
1569        # set_servername_callback expects a callable, or None
1570        self.assertRaises(TypeError, ctx.set_servername_callback)
1571        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1572        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1573        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1574
1575        def dummycallback(sock, servername, ctx):
1576            pass
1577        ctx.set_servername_callback(None)
1578        ctx.set_servername_callback(dummycallback)
1579
1580    def test_sni_callback_refcycle(self):
1581        # Reference cycles through the servername callback are detected
1582        # and cleared.
1583        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1584        def dummycallback(sock, servername, ctx, cycle=ctx):
1585            pass
1586        ctx.set_servername_callback(dummycallback)
1587        wr = weakref.ref(ctx)
1588        del ctx, dummycallback
1589        gc.collect()
1590        self.assertIs(wr(), None)
1591
1592    def test_cert_store_stats(self):
1593        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1594        self.assertEqual(ctx.cert_store_stats(),
1595            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1596        ctx.load_cert_chain(CERTFILE)
1597        self.assertEqual(ctx.cert_store_stats(),
1598            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1599        ctx.load_verify_locations(CERTFILE)
1600        self.assertEqual(ctx.cert_store_stats(),
1601            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1602        ctx.load_verify_locations(CAFILE_CACERT)
1603        self.assertEqual(ctx.cert_store_stats(),
1604            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1605
1606    def test_get_ca_certs(self):
1607        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1608        self.assertEqual(ctx.get_ca_certs(), [])
1609        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1610        ctx.load_verify_locations(CERTFILE)
1611        self.assertEqual(ctx.get_ca_certs(), [])
1612        # but CAFILE_CACERT is a CA cert
1613        ctx.load_verify_locations(CAFILE_CACERT)
1614        self.assertEqual(ctx.get_ca_certs(),
1615            [{'issuer': ((('organizationName', 'Root CA'),),
1616                         (('organizationalUnitName', 'http://www.cacert.org'),),
1617                         (('commonName', 'CA Cert Signing Authority'),),
1618                         (('emailAddress', 'support@cacert.org'),)),
1619              'notAfter': 'Mar 29 12:29:49 2033 GMT',
1620              'notBefore': 'Mar 30 12:29:49 2003 GMT',
1621              'serialNumber': '00',
1622              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1623              'subject': ((('organizationName', 'Root CA'),),
1624                          (('organizationalUnitName', 'http://www.cacert.org'),),
1625                          (('commonName', 'CA Cert Signing Authority'),),
1626                          (('emailAddress', 'support@cacert.org'),)),
1627              'version': 3}])
1628
1629        with open(CAFILE_CACERT) as f:
1630            pem = f.read()
1631        der = ssl.PEM_cert_to_DER_cert(pem)
1632        self.assertEqual(ctx.get_ca_certs(True), [der])
1633
1634    def test_load_default_certs(self):
1635        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1636        ctx.load_default_certs()
1637
1638        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1639        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1640        ctx.load_default_certs()
1641
1642        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1643        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1644
1645        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1646        self.assertRaises(TypeError, ctx.load_default_certs, None)
1647        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1648
1649    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1650    def test_load_default_certs_env(self):
1651        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1652        with os_helper.EnvironmentVarGuard() as env:
1653            env["SSL_CERT_DIR"] = CAPATH
1654            env["SSL_CERT_FILE"] = CERTFILE
1655            ctx.load_default_certs()
1656            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1657
1658    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1659    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
1660    def test_load_default_certs_env_windows(self):
1661        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1662        ctx.load_default_certs()
1663        stats = ctx.cert_store_stats()
1664
1665        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1666        with os_helper.EnvironmentVarGuard() as env:
1667            env["SSL_CERT_DIR"] = CAPATH
1668            env["SSL_CERT_FILE"] = CERTFILE
1669            ctx.load_default_certs()
1670            stats["x509"] += 1
1671            self.assertEqual(ctx.cert_store_stats(), stats)
1672
1673    def _assert_context_options(self, ctx):
1674        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1675        if OP_NO_COMPRESSION != 0:
1676            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1677                             OP_NO_COMPRESSION)
1678        if OP_SINGLE_DH_USE != 0:
1679            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1680                             OP_SINGLE_DH_USE)
1681        if OP_SINGLE_ECDH_USE != 0:
1682            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1683                             OP_SINGLE_ECDH_USE)
1684        if OP_CIPHER_SERVER_PREFERENCE != 0:
1685            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1686                             OP_CIPHER_SERVER_PREFERENCE)
1687
1688    def test_create_default_context(self):
1689        ctx = ssl.create_default_context()
1690
1691        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1692        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1693        self.assertTrue(ctx.check_hostname)
1694        self._assert_context_options(ctx)
1695
1696        with open(SIGNING_CA) as f:
1697            cadata = f.read()
1698        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1699                                         cadata=cadata)
1700        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1701        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1702        self._assert_context_options(ctx)
1703
1704        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1705        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_SERVER)
1706        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1707        self._assert_context_options(ctx)
1708
1709
1710
1711    def test__create_stdlib_context(self):
1712        ctx = ssl._create_stdlib_context()
1713        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
1714        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1715        self.assertFalse(ctx.check_hostname)
1716        self._assert_context_options(ctx)
1717
1718        with warnings_helper.check_warnings():
1719            ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1720        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1721        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1722        self._assert_context_options(ctx)
1723
1724        with warnings_helper.check_warnings():
1725            ctx = ssl._create_stdlib_context(
1726                ssl.PROTOCOL_TLSv1_2,
1727                cert_reqs=ssl.CERT_REQUIRED,
1728                check_hostname=True
1729            )
1730        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1_2)
1731        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1732        self.assertTrue(ctx.check_hostname)
1733        self._assert_context_options(ctx)
1734
1735        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1736        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_SERVER)
1737        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1738        self._assert_context_options(ctx)
1739
    def test_check_hostname(self):
        # Walks through how check_hostname and verify_mode interact on a
        # generic PROTOCOL_TLS context. Every step relies on the state left
        # by the previous one, so statement order matters.
        # Constructing a PROTOCOL_TLS context may emit a warning, which
        # check_warnings() captures.
        with warnings_helper.check_warnings():
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        # The generic context starts with hostname checking and
        # verification both off.
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)

        # Auto set CERT_REQUIRED
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        # With check_hostname off, CERT_REQUIRED can still be set explicitly.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        # Changing verify_mode does not affect check_hostname
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        # Auto set
        # (enabling check_hostname upgrades CERT_NONE to CERT_REQUIRED)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_OPTIONAL
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
        # keep CERT_OPTIONAL
        # (enabling check_hostname does not upgrade an existing
        # CERT_OPTIONAL)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)

        # Cannot set CERT_NONE with check_hostname enabled
        with self.assertRaises(ValueError):
            ctx.verify_mode = ssl.CERT_NONE
        # Once check_hostname is turned off again, CERT_NONE is accepted.
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        ctx.verify_mode = ssl.CERT_NONE
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1783
1784    def test_context_client_server(self):
1785        # PROTOCOL_TLS_CLIENT has sane defaults
1786        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1787        self.assertTrue(ctx.check_hostname)
1788        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1789
1790        # PROTOCOL_TLS_SERVER has different but also sane defaults
1791        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1792        self.assertFalse(ctx.check_hostname)
1793        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1794
1795    def test_context_custom_class(self):
1796        class MySSLSocket(ssl.SSLSocket):
1797            pass
1798
1799        class MySSLObject(ssl.SSLObject):
1800            pass
1801
1802        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1803        ctx.sslsocket_class = MySSLSocket
1804        ctx.sslobject_class = MySSLObject
1805
1806        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1807            self.assertIsInstance(sock, MySSLSocket)
1808        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(), server_side=True)
1809        self.assertIsInstance(obj, MySSLObject)
1810
1811    def test_num_tickest(self):
1812        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1813        self.assertEqual(ctx.num_tickets, 2)
1814        ctx.num_tickets = 1
1815        self.assertEqual(ctx.num_tickets, 1)
1816        ctx.num_tickets = 0
1817        self.assertEqual(ctx.num_tickets, 0)
1818        with self.assertRaises(ValueError):
1819            ctx.num_tickets = -1
1820        with self.assertRaises(TypeError):
1821            ctx.num_tickets = None
1822
1823        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1824        self.assertEqual(ctx.num_tickets, 2)
1825        with self.assertRaises(ValueError):
1826            ctx.num_tickets = 1
1827
1828
class SSLErrorTests(unittest.TestCase):
    """Tests for ssl.SSLError, its subclasses and related argument checks."""

    def test_str(self):
        # The str() of a SSLError doesn't include the errno
        e = ssl.SSLError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)
        # Same for a subclass
        e = ssl.SSLZeroReturnError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)

    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
    def test_lib_reason(self):
        # Test the library and reason attributes: loading DH parameters from
        # a file without DH data fails inside OpenSSL's PEM layer.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with self.assertRaises(ssl.SSLError) as cm:
            ctx.load_dh_params(CERTFILE)
        self.assertEqual(cm.exception.library, 'PEM')
        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
        s = str(cm.exception)
        self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)

    def test_subclass(self):
        # Check that the appropriate SSLError subclass is raised
        # (this only tests one of them)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        with socket.create_server(("127.0.0.1", 0)) as listener:
            raw = socket.create_connection(listener.getsockname())
            # wrap_socket() detaches ``raw`` on success, which makes this
            # cleanup a no-op; it only matters if wrapping itself fails.
            self.addCleanup(raw.close)
            raw.setblocking(False)
            with ctx.wrap_socket(raw, server_side=False,
                                 do_handshake_on_connect=False) as client:
                # Nothing on the listening side ever accepts or writes, so
                # the non-blocking handshake has no server data to read yet.
                with self.assertRaises(ssl.SSLWantReadError) as cm:
                    client.do_handshake()
                msg = str(cm.exception)
                self.assertTrue(
                    msg.startswith("The operation did not complete (read)"),
                    msg)
                # For compatibility
                self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)

    def test_bad_server_hostname(self):
        # wrap_bio() rejects unusable server_hostname values: an empty name
        # or one starting with a dot is a ValueError, and a name with an
        # embedded NUL is a TypeError.
        ctx = ssl.create_default_context()
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="")
        with self.assertRaises(ValueError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname=".example.org")
        with self.assertRaises(TypeError):
            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
                         server_hostname="example.org\x00evil.com")
1881
1882
class MemoryBIOTests(unittest.TestCase):
    """Behaviour of the in-memory ssl.MemoryBIO byte buffer."""

    def test_read_write(self):
        # Bytes come out in the order written. read() drains everything;
        # read(n) returns at most n bytes.
        buf = ssl.MemoryBIO()
        buf.write(b'foo')
        self.assertEqual(buf.read(), b'foo')
        self.assertEqual(buf.read(), b'')
        for chunk in (b'foo', b'bar'):
            buf.write(chunk)
        self.assertEqual(buf.read(), b'foobar')
        self.assertEqual(buf.read(), b'')
        buf.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(buf.read(size), expected)

    def test_eof(self):
        # eof turns true only after write_eof() has been called AND every
        # buffered byte has been read.
        buf = ssl.MemoryBIO()
        self.assertFalse(buf.eof)
        self.assertEqual(buf.read(), b'')
        self.assertFalse(buf.eof)
        buf.write(b'foo')
        self.assertFalse(buf.eof)
        buf.write_eof()
        self.assertFalse(buf.eof)
        self.assertEqual(buf.read(2), b'fo')
        self.assertFalse(buf.eof)
        self.assertEqual(buf.read(1), b'o')
        self.assertTrue(buf.eof)
        self.assertEqual(buf.read(), b'')
        self.assertTrue(buf.eof)

    def test_pending(self):
        # pending is the number of buffered bytes not yet read.
        buf = ssl.MemoryBIO()
        self.assertEqual(buf.pending, 0)
        buf.write(b'foo')
        self.assertEqual(buf.pending, 3)
        for remaining in (2, 1, 0):
            buf.read(1)
            self.assertEqual(buf.pending, remaining)
        for total in (1, 2, 3):
            buf.write(b'x')
            self.assertEqual(buf.pending, total)
        buf.read()
        self.assertEqual(buf.pending, 0)

    def test_buffer_types(self):
        # write() accepts bytes, bytearray and memoryview alike.
        buf = ssl.MemoryBIO()
        for payload in (b'foo', bytearray(b'bar'), memoryview(b'baz')):
            buf.write(payload)
            self.assertEqual(buf.read(), bytes(payload))

    def test_error_types(self):
        # str and non-buffer objects are rejected with TypeError.
        buf = ssl.MemoryBIO()
        for not_a_buffer in ('foo', None, True, 1):
            self.assertRaises(TypeError, buf.write, not_a_buffer)
1944
1945
class SSLObjectTests(unittest.TestCase):
    """Tests for ssl.SSLObject driven purely through MemoryBIO pairs."""

    def test_private_init(self):
        # SSLObject is not meant to be constructed directly.
        bio = ssl.MemoryBIO()
        with self.assertRaisesRegex(TypeError, "public constructor"):
            ssl.SSLObject(bio, bio)

    def test_unwrap(self):
        # NOTE(review): testing_context() is a module helper defined
        # elsewhere; from the unpacking it returns (client context, server
        # context, hostname to use for SNI/verification).
        client_ctx, server_ctx, hostname = testing_context()
        # Each side gets an incoming and an outgoing BIO; bytes are moved
        # between them by hand below (c_out -> s_in, s_out -> c_in).
        c_in = ssl.MemoryBIO()
        c_out = ssl.MemoryBIO()
        s_in = ssl.MemoryBIO()
        s_out = ssl.MemoryBIO()
        client = client_ctx.wrap_bio(c_in, c_out, server_hostname=hostname)
        server = server_ctx.wrap_bio(s_in, s_out, server_side=True)

        # Loop on the handshake for a bit to get it settled
        # (each round advances both sides and forwards whatever they wrote;
        # WantRead just means "waiting for the peer's next flight")
        for _ in range(5):
            try:
                client.do_handshake()
            except ssl.SSLWantReadError:
                pass
            if c_out.pending:
                s_in.write(c_out.read())
            try:
                server.do_handshake()
            except ssl.SSLWantReadError:
                pass
            if s_out.pending:
                c_in.write(s_out.read())
        # Now the handshakes should be complete (don't raise WantReadError)
        client.do_handshake()
        server.do_handshake()

        # Now if we unwrap one side unilaterally, it should send close-notify
        # and raise WantReadError:
        with self.assertRaises(ssl.SSLWantReadError):
            client.unwrap()

        # But server.unwrap() does not raise, because it reads the client's
        # close-notify:
        s_in.write(c_out.read())
        server.unwrap()

        # And now that the client gets the server's close-notify, it doesn't
        # raise either.
        c_in.write(s_out.read())
        client.unwrap()
1993
1994class SimpleBackgroundTests(unittest.TestCase):
1995    """Tests that connect to a simple server running in the background"""
1996
1997    def setUp(self):
1998        self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1999        self.server_context.load_cert_chain(SIGNED_CERTFILE)
2000        server = ThreadedEchoServer(context=self.server_context)
2001        self.server_addr = (HOST, server.port)
2002        server.__enter__()
2003        self.addCleanup(server.__exit__, None, None, None)
2004
2005    def test_connect(self):
2006        with test_wrap_socket(socket.socket(socket.AF_INET),
2007                            cert_reqs=ssl.CERT_NONE) as s:
2008            s.connect(self.server_addr)
2009            self.assertEqual({}, s.getpeercert())
2010            self.assertFalse(s.server_side)
2011
2012        # this should succeed because we specify the root cert
2013        with test_wrap_socket(socket.socket(socket.AF_INET),
2014                            cert_reqs=ssl.CERT_REQUIRED,
2015                            ca_certs=SIGNING_CA) as s:
2016            s.connect(self.server_addr)
2017            self.assertTrue(s.getpeercert())
2018            self.assertFalse(s.server_side)
2019
2020    def test_connect_fail(self):
2021        # This should fail because we have no verification certs. Connection
2022        # failure crashes ThreadedEchoServer, so run this in an independent
2023        # test method.
2024        s = test_wrap_socket(socket.socket(socket.AF_INET),
2025                            cert_reqs=ssl.CERT_REQUIRED)
2026        self.addCleanup(s.close)
2027        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2028                               s.connect, self.server_addr)
2029
2030    def test_connect_ex(self):
2031        # Issue #11326: check connect_ex() implementation
2032        s = test_wrap_socket(socket.socket(socket.AF_INET),
2033                            cert_reqs=ssl.CERT_REQUIRED,
2034                            ca_certs=SIGNING_CA)
2035        self.addCleanup(s.close)
2036        self.assertEqual(0, s.connect_ex(self.server_addr))
2037        self.assertTrue(s.getpeercert())
2038
    def test_non_blocking_connect_ex(self):
        # Issue #11326: non-blocking connect_ex() should allow handshake
        # to proceed after the socket gets ready.
        s = test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_REQUIRED,
                            ca_certs=SIGNING_CA,
                            do_handshake_on_connect=False)
        self.addCleanup(s.close)
        s.setblocking(False)
        rc = s.connect_ex(self.server_addr)
        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
        # Wait for connect to finish
        # (writability signals TCP connect completion; bounded to 5 s)
        select.select([], [s], [], 5.0)
        # Non-blocking handshake
        # Retry do_handshake(), waiting up to 5 s per round for the
        # direction OpenSSL asks for. NOTE(review): the loop itself has no
        # overall bound, so a peer that never answers keeps it spinning.
        while True:
            try:
                s.do_handshake()
                break
            except ssl.SSLWantReadError:
                select.select([s], [], [], 5.0)
            except ssl.SSLWantWriteError:
                select.select([], [s], [], 5.0)
        # SSL established
        self.assertTrue(s.getpeercert())
2064
2065    def test_connect_with_context(self):
2066        # Same as test_connect, but with a separately created context
2067        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2068        ctx.check_hostname = False
2069        ctx.verify_mode = ssl.CERT_NONE
2070        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2071            s.connect(self.server_addr)
2072            self.assertEqual({}, s.getpeercert())
2073        # Same with a server hostname
2074        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2075                            server_hostname="dummy") as s:
2076            s.connect(self.server_addr)
2077        ctx.verify_mode = ssl.CERT_REQUIRED
2078        # This should succeed because we specify the root cert
2079        ctx.load_verify_locations(SIGNING_CA)
2080        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2081            s.connect(self.server_addr)
2082            cert = s.getpeercert()
2083            self.assertTrue(cert)
2084
2085    def test_connect_with_context_fail(self):
2086        # This should fail because we have no verification certs. Connection
2087        # failure crashes ThreadedEchoServer, so run this in an independent
2088        # test method.
2089        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2090        s = ctx.wrap_socket(
2091            socket.socket(socket.AF_INET),
2092            server_hostname=SIGNED_CERTFILE_HOSTNAME
2093        )
2094        self.addCleanup(s.close)
2095        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2096                                s.connect, self.server_addr)
2097
2098    def test_connect_capath(self):
2099        # Verify server certificates using the `capath` argument
2100        # NOTE: the subject hashing algorithm has been changed between
2101        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
2102        # contain both versions of each certificate (same content, different
2103        # filename) for this test to be portable across OpenSSL releases.
2104        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2105        ctx.load_verify_locations(capath=CAPATH)
2106        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2107                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2108            s.connect(self.server_addr)
2109            cert = s.getpeercert()
2110            self.assertTrue(cert)
2111
2112        # Same with a bytes `capath` argument
2113        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2114        ctx.load_verify_locations(capath=BYTES_CAPATH)
2115        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2116                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2117            s.connect(self.server_addr)
2118            cert = s.getpeercert()
2119            self.assertTrue(cert)
2120
2121    def test_connect_cadata(self):
2122        with open(SIGNING_CA) as f:
2123            pem = f.read()
2124        der = ssl.PEM_cert_to_DER_cert(pem)
2125        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2126        ctx.load_verify_locations(cadata=pem)
2127        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2128                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2129            s.connect(self.server_addr)
2130            cert = s.getpeercert()
2131            self.assertTrue(cert)
2132
2133        # same with DER
2134        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2135        ctx.load_verify_locations(cadata=der)
2136        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2137                             server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
2138            s.connect(self.server_addr)
2139            cert = s.getpeercert()
2140            self.assertTrue(cert)
2141
    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
    def test_makefile_close(self):
        # Issue #5238: creating a file-like object with makefile() shouldn't
        # delay closing the underlying "real socket" (here tested with its
        # file descriptor, hence skipping the test under Windows).
        # os.read(fd, 0) serves as a liveness probe: it succeeds on an open
        # descriptor and the test expects EBADF once the descriptor is closed.
        # NOTE(review): if ss.connect() raises, ss is never closed; also, in
        # principle the fd number could be reused (e.g. by the background
        # server thread) between ss.close() and the final probe.
        ss = test_wrap_socket(socket.socket(socket.AF_INET))
        ss.connect(self.server_addr)
        fd = ss.fileno()
        f = ss.makefile()
        f.close()
        # The fd is still open
        os.read(fd, 0)
        # Closing the SSL socket should close the fd too
        ss.close()
        gc.collect()
        with self.assertRaises(OSError) as e:
            os.read(fd, 0)
        self.assertEqual(e.exception.errno, errno.EBADF)
2160
2161    def test_non_blocking_handshake(self):
2162        s = socket.socket(socket.AF_INET)
2163        s.connect(self.server_addr)
2164        s.setblocking(False)
2165        s = test_wrap_socket(s,
2166                            cert_reqs=ssl.CERT_NONE,
2167                            do_handshake_on_connect=False)
2168        self.addCleanup(s.close)
2169        count = 0
2170        while True:
2171            try:
2172                count += 1
2173                s.do_handshake()
2174                break
2175            except ssl.SSLWantReadError:
2176                select.select([s], [], [])
2177            except ssl.SSLWantWriteError:
2178                select.select([], [s], [])
2179        if support.verbose:
2180            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
2181
2182    def test_get_server_certificate(self):
2183        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
2184
2185    def test_get_server_certificate_sni(self):
2186        host, port = self.server_addr
2187        server_names = []
2188
2189        # We store servername_cb arguments to make sure they match the host
2190        def servername_cb(ssl_sock, server_name, initial_context):
2191            server_names.append(server_name)
2192        self.server_context.set_servername_callback(servername_cb)
2193
2194        pem = ssl.get_server_certificate((host, port))
2195        if not pem:
2196            self.fail("No server certificate on %s:%s!" % (host, port))
2197
2198        pem = ssl.get_server_certificate((host, port), ca_certs=SIGNING_CA)
2199        if not pem:
2200            self.fail("No server certificate on %s:%s!" % (host, port))
2201        if support.verbose:
2202            sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
2203
2204        self.assertEqual(server_names, [host, host])
2205
2206    def test_get_server_certificate_fail(self):
2207        # Connection failure crashes ThreadedEchoServer, so run this in an
2208        # independent test method
2209        _test_get_server_certificate_fail(self, *self.server_addr)
2210
2211    def test_get_server_certificate_timeout(self):
2212        def servername_cb(ssl_sock, server_name, initial_context):
2213            time.sleep(0.2)
2214        self.server_context.set_servername_callback(servername_cb)
2215
2216        with self.assertRaises(socket.timeout):
2217            ssl.get_server_certificate(self.server_addr, ca_certs=SIGNING_CA,
2218                                       timeout=0.1)
2219
2220    def test_ciphers(self):
2221        with test_wrap_socket(socket.socket(socket.AF_INET),
2222                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
2223            s.connect(self.server_addr)
2224        with test_wrap_socket(socket.socket(socket.AF_INET),
2225                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
2226            s.connect(self.server_addr)
2227        # Error checking can happen at instantiation or when connecting
2228        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
2229            with socket.socket(socket.AF_INET) as sock:
2230                s = test_wrap_socket(sock,
2231                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
2232                s.connect(self.server_addr)
2233
2234    def test_get_ca_certs_capath(self):
2235        # capath certs are loaded on request
2236        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2237        ctx.load_verify_locations(capath=CAPATH)
2238        self.assertEqual(ctx.get_ca_certs(), [])
2239        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2240                             server_hostname='localhost') as s:
2241            s.connect(self.server_addr)
2242            cert = s.getpeercert()
2243            self.assertTrue(cert)
2244        self.assertEqual(len(ctx.get_ca_certs()), 1)
2245
2246    def test_context_setget(self):
2247        # Check that the context of a connected socket can be replaced.
2248        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2249        ctx1.load_verify_locations(capath=CAPATH)
2250        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2251        ctx2.load_verify_locations(capath=CAPATH)
2252        s = socket.socket(socket.AF_INET)
2253        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
2254            ss.connect(self.server_addr)
2255            self.assertIs(ss.context, ctx1)
2256            self.assertIs(ss._sslobj.context, ctx1)
2257            ss.context = ctx2
2258            self.assertIs(ss.context, ctx2)
2259            self.assertIs(ss._sslobj.context, ctx2)
2260
2261    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
2262        # A simple IO loop. Call func(*args) depending on the error we get
2263        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
2264        timeout = kwargs.get('timeout', support.SHORT_TIMEOUT)
2265        deadline = time.monotonic() + timeout
2266        count = 0
2267        while True:
2268            if time.monotonic() > deadline:
2269                self.fail("timeout")
2270            errno = None
2271            count += 1
2272            try:
2273                ret = func(*args)
2274            except ssl.SSLError as e:
2275                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
2276                                   ssl.SSL_ERROR_WANT_WRITE):
2277                    raise
2278                errno = e.errno
2279            # Get any data from the outgoing BIO irrespective of any error, and
2280            # send it to the socket.
2281            buf = outgoing.read()
2282            sock.sendall(buf)
2283            # If there's no error, we're done. For WANT_READ, we need to get
2284            # data from the socket and put it in the incoming BIO.
2285            if errno is None:
2286                break
2287            elif errno == ssl.SSL_ERROR_WANT_READ:
2288                buf = sock.recv(32768)
2289                if buf:
2290                    incoming.write(buf)
2291                else:
2292                    incoming.write_eof()
2293        if support.verbose:
2294            sys.stdout.write("Needed %d calls to complete %s().\n"
2295                             % (count, func.__name__))
2296        return ret
2297
2298    def test_bio_handshake(self):
2299        sock = socket.socket(socket.AF_INET)
2300        self.addCleanup(sock.close)
2301        sock.connect(self.server_addr)
2302        incoming = ssl.MemoryBIO()
2303        outgoing = ssl.MemoryBIO()
2304        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2305        self.assertTrue(ctx.check_hostname)
2306        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
2307        ctx.load_verify_locations(SIGNING_CA)
2308        sslobj = ctx.wrap_bio(incoming, outgoing, False,
2309                              SIGNED_CERTFILE_HOSTNAME)
2310        self.assertIs(sslobj._sslobj.owner, sslobj)
2311        self.assertIsNone(sslobj.cipher())
2312        self.assertIsNone(sslobj.version())
2313        self.assertIsNotNone(sslobj.shared_ciphers())
2314        self.assertRaises(ValueError, sslobj.getpeercert)
2315        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2316            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
2317        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2318        self.assertTrue(sslobj.cipher())
2319        self.assertIsNotNone(sslobj.shared_ciphers())
2320        self.assertIsNotNone(sslobj.version())
2321        self.assertTrue(sslobj.getpeercert())
2322        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2323            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
2324        try:
2325            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2326        except ssl.SSLSyscallError:
2327            # If the server shuts down the TCP connection without sending a
2328            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
2329            pass
2330        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
2331
2332    def test_bio_read_write_data(self):
2333        sock = socket.socket(socket.AF_INET)
2334        self.addCleanup(sock.close)
2335        sock.connect(self.server_addr)
2336        incoming = ssl.MemoryBIO()
2337        outgoing = ssl.MemoryBIO()
2338        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2339        ctx.check_hostname = False
2340        ctx.verify_mode = ssl.CERT_NONE
2341        sslobj = ctx.wrap_bio(incoming, outgoing, False)
2342        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2343        req = b'FOO\n'
2344        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2345        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2346        self.assertEqual(buf, b'foo\n')
2347        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2348
2349
@support.requires_resource('network')
class NetworkedTests(unittest.TestCase):
    """SSL tests that need to reach real hosts on the internet."""

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        with socket_helper.transient_internet(REMOTE_HOST):
            sslsock = test_wrap_socket(socket.socket(socket.AF_INET),
                                       cert_reqs=ssl.CERT_REQUIRED,
                                       do_handshake_on_connect=False)
            self.addCleanup(sslsock.close)
            sslsock.settimeout(0.0000001)
            result = sslsock.connect_ex((REMOTE_HOST, 443))
            # skipTest() raises, so these checks are mutually exclusive.
            if result == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            if result == errno.ENETUNREACH:
                self.skipTest("Network unreachable.")
            self.assertIn(result, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        host = 'ipv6.google.com'
        with socket_helper.transient_internet(host):
            _test_get_server_certificate(self, host, 443)
            _test_get_server_certificate_fail(self, host, 443)
2368
2369    @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'Needs IPv6')
2370    def test_get_server_certificate_ipv6(self):
2371        with socket_helper.transient_internet('ipv6.google.com'):
2372            _test_get_server_certificate(self, 'ipv6.google.com', 443)
2373            _test_get_server_certificate_fail(self, 'ipv6.google.com', 443)
2374
2375
2376def _test_get_server_certificate(test, host, port, cert=None):
2377    pem = ssl.get_server_certificate((host, port))
2378    if not pem:
2379        test.fail("No server certificate on %s:%s!" % (host, port))
2380
2381    pem = ssl.get_server_certificate((host, port), ca_certs=cert)
2382    if not pem:
2383        test.fail("No server certificate on %s:%s!" % (host, port))
2384    if support.verbose:
2385        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem))
2386
2387def _test_get_server_certificate_fail(test, host, port):
2388    try:
2389        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
2390    except ssl.SSLError as x:
2391        #should fail
2392        if support.verbose:
2393            sys.stdout.write("%s\n" % x)
2394    else:
2395        test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
2396
2397
2398from test.ssl_servers import make_https_server
2399
2400class ThreadedEchoServer(threading.Thread):
2401
2402    class ConnectionHandler(threading.Thread):
2403
2404        """A mildly complicated class, because we want it to work both
2405        with and without the SSL wrapper around the socket connection, so
2406        that we can test the STARTTLS functionality."""
2407
2408        def __init__(self, server, connsock, addr):
2409            self.server = server
2410            self.running = False
2411            self.sock = connsock
2412            self.addr = addr
2413            self.sock.setblocking(True)
2414            self.sslconn = None
2415            threading.Thread.__init__(self)
2416            self.daemon = True
2417
2418        def wrap_conn(self):
2419            try:
2420                self.sslconn = self.server.context.wrap_socket(
2421                    self.sock, server_side=True)
2422                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
2423            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError) as e:
2424                # We treat ConnectionResetError as though it were an
2425                # SSLError - OpenSSL on Ubuntu abruptly closes the
2426                # connection when asked to use an unsupported protocol.
2427                #
2428                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
2429                # tries to send session tickets after handshake.
2430                # https://github.com/openssl/openssl/issues/6342
2431                #
2432                # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL
2433                # tries to send session tickets after handshake when using WinSock.
2434                self.server.conn_errors.append(str(e))
2435                if self.server.chatty:
2436                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
2437                self.running = False
2438                self.close()
2439                return False
2440            except (ssl.SSLError, OSError) as e:
2441                # OSError may occur with wrong protocols, e.g. both
2442                # sides use PROTOCOL_TLS_SERVER.
2443                #
2444                # XXX Various errors can have happened here, for example
2445                # a mismatching protocol version, an invalid certificate,
2446                # or a low-level bug. This should be made more discriminating.
2447                #
2448                # bpo-31323: Store the exception as string to prevent
2449                # a reference leak: server -> conn_errors -> exception
2450                # -> traceback -> self (ConnectionHandler) -> server
2451                self.server.conn_errors.append(str(e))
2452                if self.server.chatty:
2453                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
2454
2455                # bpo-44229, bpo-43855, bpo-44237, and bpo-33450:
2456                # Ignore spurious EPROTOTYPE returned by write() on macOS.
2457                # See also http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
2458                if e.errno != errno.EPROTOTYPE and sys.platform != "darwin":
2459                    self.running = False
2460                    self.server.stop()
2461                    self.close()
2462                return False
2463            else:
2464                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
2465                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
2466                    cert = self.sslconn.getpeercert()
2467                    if support.verbose and self.server.chatty:
2468                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
2469                    cert_binary = self.sslconn.getpeercert(True)
2470                    if support.verbose and self.server.chatty:
2471                        if cert_binary is None:
2472                            sys.stdout.write(" client did not provide a cert\n")
2473                        else:
2474                            sys.stdout.write(f" cert binary is {len(cert_binary)}b\n")
2475                cipher = self.sslconn.cipher()
2476                if support.verbose and self.server.chatty:
2477                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
2478                return True
2479
2480        def read(self):
2481            if self.sslconn:
2482                return self.sslconn.read()
2483            else:
2484                return self.sock.recv(1024)
2485
2486        def write(self, bytes):
2487            if self.sslconn:
2488                return self.sslconn.write(bytes)
2489            else:
2490                return self.sock.send(bytes)
2491
2492        def close(self):
2493            if self.sslconn:
2494                self.sslconn.close()
2495            else:
2496                self.sock.close()
2497
2498        def run(self):
2499            self.running = True
2500            if not self.server.starttls_server:
2501                if not self.wrap_conn():
2502                    return
2503            while self.running:
2504                try:
2505                    msg = self.read()
2506                    stripped = msg.strip()
2507                    if not stripped:
2508                        # eof, so quit this handler
2509                        self.running = False
2510                        try:
2511                            self.sock = self.sslconn.unwrap()
2512                        except OSError:
2513                            # Many tests shut the TCP connection down
2514                            # without an SSL shutdown. This causes
2515                            # unwrap() to raise OSError with errno=0!
2516                            pass
2517                        else:
2518                            self.sslconn = None
2519                        self.close()
2520                    elif stripped == b'over':
2521                        if support.verbose and self.server.connectionchatty:
2522                            sys.stdout.write(" server: client closed connection\n")
2523                        self.close()
2524                        return
2525                    elif (self.server.starttls_server and
2526                          stripped == b'STARTTLS'):
2527                        if support.verbose and self.server.connectionchatty:
2528                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
2529                        self.write(b"OK\n")
2530                        if not self.wrap_conn():
2531                            return
2532                    elif (self.server.starttls_server and self.sslconn
2533                          and stripped == b'ENDTLS'):
2534                        if support.verbose and self.server.connectionchatty:
2535                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
2536                        self.write(b"OK\n")
2537                        self.sock = self.sslconn.unwrap()
2538                        self.sslconn = None
2539                        if support.verbose and self.server.connectionchatty:
2540                            sys.stdout.write(" server: connection is now unencrypted...\n")
2541                    elif stripped == b'CB tls-unique':
2542                        if support.verbose and self.server.connectionchatty:
2543                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
2544                        data = self.sslconn.get_channel_binding("tls-unique")
2545                        self.write(repr(data).encode("us-ascii") + b"\n")
2546                    elif stripped == b'PHA':
2547                        if support.verbose and self.server.connectionchatty:
2548                            sys.stdout.write(" server: initiating post handshake auth\n")
2549                        try:
2550                            self.sslconn.verify_client_post_handshake()
2551                        except ssl.SSLError as e:
2552                            self.write(repr(e).encode("us-ascii") + b"\n")
2553                        else:
2554                            self.write(b"OK\n")
2555                    elif stripped == b'HASCERT':
2556                        if self.sslconn.getpeercert() is not None:
2557                            self.write(b'TRUE\n')
2558                        else:
2559                            self.write(b'FALSE\n')
2560                    elif stripped == b'GETCERT':
2561                        cert = self.sslconn.getpeercert()
2562                        self.write(repr(cert).encode("us-ascii") + b"\n")
2563                    elif stripped == b'VERIFIEDCHAIN':
2564                        certs = self.sslconn._sslobj.get_verified_chain()
2565                        self.write(len(certs).to_bytes(1, "big") + b"\n")
2566                    elif stripped == b'UNVERIFIEDCHAIN':
2567                        certs = self.sslconn._sslobj.get_unverified_chain()
2568                        self.write(len(certs).to_bytes(1, "big") + b"\n")
2569                    else:
2570                        if (support.verbose and
2571                            self.server.connectionchatty):
2572                            ctype = (self.sslconn and "encrypted") or "unencrypted"
2573                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
2574                                             % (msg, ctype, msg.lower(), ctype))
2575                        self.write(msg.lower())
2576                except OSError as e:
2577                    # handles SSLError and socket errors
2578                    if self.server.chatty and support.verbose:
2579                        if isinstance(e, ConnectionError):
2580                            # OpenSSL 1.1.1 sometimes raises
2581                            # ConnectionResetError when connection is not
2582                            # shut down gracefully.
2583                            print(
2584                                f" Connection reset by peer: {self.addr}"
2585                            )
2586                        else:
2587                            handle_error("Test server failure:\n")
2588                    try:
2589                        self.write(b"ERROR\n")
2590                    except OSError:
2591                        pass
2592                    self.close()
2593                    self.running = False
2594
2595                    # normally, we'd just stop here, but for the test
2596                    # harness, we want to stop the server
2597                    self.server.stop()
2598
2599    def __init__(self, certificate=None, ssl_version=None,
2600                 certreqs=None, cacerts=None,
2601                 chatty=True, connectionchatty=False, starttls_server=False,
2602                 alpn_protocols=None,
2603                 ciphers=None, context=None):
2604        if context:
2605            self.context = context
2606        else:
2607            self.context = ssl.SSLContext(ssl_version
2608                                          if ssl_version is not None
2609                                          else ssl.PROTOCOL_TLS_SERVER)
2610            self.context.verify_mode = (certreqs if certreqs is not None
2611                                        else ssl.CERT_NONE)
2612            if cacerts:
2613                self.context.load_verify_locations(cacerts)
2614            if certificate:
2615                self.context.load_cert_chain(certificate)
2616            if alpn_protocols:
2617                self.context.set_alpn_protocols(alpn_protocols)
2618            if ciphers:
2619                self.context.set_ciphers(ciphers)
2620        self.chatty = chatty
2621        self.connectionchatty = connectionchatty
2622        self.starttls_server = starttls_server
2623        self.sock = socket.socket()
2624        self.port = socket_helper.bind_port(self.sock)
2625        self.flag = None
2626        self.active = False
2627        self.selected_alpn_protocols = []
2628        self.shared_ciphers = []
2629        self.conn_errors = []
2630        threading.Thread.__init__(self)
2631        self.daemon = True
2632
2633    def __enter__(self):
2634        self.start(threading.Event())
2635        self.flag.wait()
2636        return self
2637
2638    def __exit__(self, *args):
2639        self.stop()
2640        self.join()
2641
2642    def start(self, flag=None):
2643        self.flag = flag
2644        threading.Thread.start(self)
2645
2646    def run(self):
2647        self.sock.settimeout(1.0)
2648        self.sock.listen(5)
2649        self.active = True
2650        if self.flag:
2651            # signal an event
2652            self.flag.set()
2653        while self.active:
2654            try:
2655                newconn, connaddr = self.sock.accept()
2656                if support.verbose and self.chatty:
2657                    sys.stdout.write(' server:  new connection from '
2658                                     + repr(connaddr) + '\n')
2659                handler = self.ConnectionHandler(self, newconn, connaddr)
2660                handler.start()
2661                handler.join()
2662            except TimeoutError as e:
2663                if support.verbose:
2664                    sys.stdout.write(f' connection timeout {e!r}\n')
2665            except KeyboardInterrupt:
2666                self.stop()
2667            except BaseException as e:
2668                if support.verbose and self.chatty:
2669                    sys.stdout.write(
2670                        ' connection handling failed: ' + repr(e) + '\n')
2671
2672        self.close()
2673
2674    def close(self):
2675        if self.sock is not None:
2676            self.sock.close()
2677            self.sock = None
2678
2679    def stop(self):
2680        self.active = False
2681
2682class AsyncoreEchoServer(threading.Thread):
2683
2684    # this one's based on asyncore.dispatcher
2685
2686    class EchoServer (asyncore.dispatcher):
2687
2688        class ConnectionHandler(asyncore.dispatcher_with_send):
2689
2690            def __init__(self, conn, certfile):
2691                self.socket = test_wrap_socket(conn, server_side=True,
2692                                              certfile=certfile,
2693                                              do_handshake_on_connect=False)
2694                asyncore.dispatcher_with_send.__init__(self, self.socket)
2695                self._ssl_accepting = True
2696                self._do_ssl_handshake()
2697
2698            def readable(self):
2699                if isinstance(self.socket, ssl.SSLSocket):
2700                    while self.socket.pending() > 0:
2701                        self.handle_read_event()
2702                return True
2703
2704            def _do_ssl_handshake(self):
2705                try:
2706                    self.socket.do_handshake()
2707                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
2708                    return
2709                except ssl.SSLEOFError:
2710                    return self.handle_close()
2711                except ssl.SSLError:
2712                    raise
2713                except OSError as err:
2714                    if err.args[0] == errno.ECONNABORTED:
2715                        return self.handle_close()
2716                else:
2717                    self._ssl_accepting = False
2718
2719            def handle_read(self):
2720                if self._ssl_accepting:
2721                    self._do_ssl_handshake()
2722                else:
2723                    data = self.recv(1024)
2724                    if support.verbose:
2725                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
2726                    if not data:
2727                        self.close()
2728                    else:
2729                        self.send(data.lower())
2730
2731            def handle_close(self):
2732                self.close()
2733                if support.verbose:
2734                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)
2735
2736            def handle_error(self):
2737                raise
2738
2739        def __init__(self, certfile):
2740            self.certfile = certfile
2741            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
2742            self.port = socket_helper.bind_port(sock, '')
2743            asyncore.dispatcher.__init__(self, sock)
2744            self.listen(5)
2745
2746        def handle_accepted(self, sock_obj, addr):
2747            if support.verbose:
2748                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
2749            self.ConnectionHandler(sock_obj, self.certfile)
2750
2751        def handle_error(self):
2752            raise
2753
2754    def __init__(self, certfile):
2755        self.flag = None
2756        self.active = False
2757        self.server = self.EchoServer(certfile)
2758        self.port = self.server.port
2759        threading.Thread.__init__(self)
2760        self.daemon = True
2761
2762    def __str__(self):
2763        return "<%s %s>" % (self.__class__.__name__, self.server)
2764
2765    def __enter__(self):
2766        self.start(threading.Event())
2767        self.flag.wait()
2768        return self
2769
2770    def __exit__(self, *args):
2771        if support.verbose:
2772            sys.stdout.write(" cleanup: stopping server.\n")
2773        self.stop()
2774        if support.verbose:
2775            sys.stdout.write(" cleanup: joining server thread.\n")
2776        self.join()
2777        if support.verbose:
2778            sys.stdout.write(" cleanup: successfully joined.\n")
2779        # make sure that ConnectionHandler is removed from socket_map
2780        asyncore.close_all(ignore_all=True)
2781
2782    def start (self, flag=None):
2783        self.flag = flag
2784        threading.Thread.start(self)
2785
2786    def run(self):
2787        self.active = True
2788        if self.flag:
2789            self.flag.set()
2790        while self.active:
2791            try:
2792                asyncore.loop(1)
2793            except:
2794                pass
2795
2796    def stop(self):
2797        self.active = False
2798        self.server.close()
2799
2800def server_params_test(client_context, server_context, indata=b"FOO\n",
2801                       chatty=True, connectionchatty=False, sni_name=None,
2802                       session=None):
2803    """
2804    Launch a server, connect a client to it and try various reads
2805    and writes.
2806    """
2807    stats = {}
2808    server = ThreadedEchoServer(context=server_context,
2809                                chatty=chatty,
2810                                connectionchatty=False)
2811    with server:
2812        with client_context.wrap_socket(socket.socket(),
2813                server_hostname=sni_name, session=session) as s:
2814            s.connect((HOST, server.port))
2815            for arg in [indata, bytearray(indata), memoryview(indata)]:
2816                if connectionchatty:
2817                    if support.verbose:
2818                        sys.stdout.write(
2819                            " client:  sending %r...\n" % indata)
2820                s.write(arg)
2821                outdata = s.read()
2822                if connectionchatty:
2823                    if support.verbose:
2824                        sys.stdout.write(" client:  read %r\n" % outdata)
2825                if outdata != indata.lower():
2826                    raise AssertionError(
2827                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
2828                        % (outdata[:20], len(outdata),
2829                           indata[:20].lower(), len(indata)))
2830            s.write(b"over\n")
2831            if connectionchatty:
2832                if support.verbose:
2833                    sys.stdout.write(" client:  closing connection.\n")
2834            stats.update({
2835                'compression': s.compression(),
2836                'cipher': s.cipher(),
2837                'peercert': s.getpeercert(),
2838                'client_alpn_protocol': s.selected_alpn_protocol(),
2839                'version': s.version(),
2840                'session_reused': s.session_reused,
2841                'session': s.session,
2842            })
2843            s.close()
2844        stats['server_alpn_protocols'] = server.selected_alpn_protocols
2845        stats['server_shared_ciphers'] = server.shared_ciphers
2846    return stats
2847
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """
    Try to SSL-connect using *client_protocol* to *server_protocol*.
    If *expect_success* is true, assert that the connection succeeds,
    if it's false, assert that the connection fails.
    Also, if *expect_success* is a string, assert that it is the protocol
    version actually used by the connection.

    *certsreqs* (``ssl.CERT_NONE`` when omitted) becomes the ``verify_mode``
    of both contexts; *server_options* and *client_options* are OR-ed into
    the ``options`` of the server and client context respectively.

    A failed connection is either an ``ssl.SSLError`` or an ``OSError``
    with errno ``ECONNRESET``.  It is re-raised when success was expected.
    Any other ``OSError`` is always re-raised.  An unexpected success, or a
    negotiated version different from a string *expect_success*, raises
    AssertionError.
    """
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    certtype = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }[certsreqs]
    if support.verbose:
        # Combinations that are expected to fail are printed in braces.
        formatstr = " %s->%s %s\n" if expect_success else " {%s->%s} %s\n"
        sys.stdout.write(formatstr %
                         (ssl.get_protocol_name(client_protocol),
                          ssl.get_protocol_name(server_protocol),
                          certtype))

    with warnings_helper.check_warnings():
        # ignore Deprecation warnings
        client_context = ssl.SSLContext(client_protocol)
        client_context.options |= client_options
        server_context = ssl.SSLContext(server_protocol)
        server_context.options |= server_options

    min_version = PROTOCOL_TO_TLS_VERSION.get(client_protocol, None)
    if (min_version is not None
        # SSLContext.minimum_version is only available on recent OpenSSL
        # (setter added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1)
        and hasattr(server_context, 'minimum_version')
        and server_protocol == ssl.PROTOCOL_TLS
        and server_context.minimum_version > min_version
    ):
        # If OpenSSL configuration is strict and requires more recent TLS
        # version, we have to change the minimum to test old TLS versions.
        with warnings_helper.check_warnings():
            server_context.minimum_version = min_version

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    seclevel_workaround(server_context, client_context)

    # Both sides present the signed certificate and trust the signing CA.
    for ctx in (client_context, server_context):
        ctx.verify_mode = certsreqs
        ctx.load_cert_chain(SIGNED_CERTFILE)
        ctx.load_verify_locations(SIGNING_CA)
    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    # Protocol mismatch can result in either an SSLError, or a
    # "Connection reset by peer" error.
    except ssl.SSLError:
        if expect_success:
            raise
    except OSError as e:
        if expect_success or e.errno != errno.ECONNRESET:
            raise
    else:
        if not expect_success:
            raise AssertionError(
                "Client protocol %s succeeded with server protocol %s!"
                % (ssl.get_protocol_name(client_protocol),
                   ssl.get_protocol_name(server_protocol)))
        elif (expect_success is not True
              and expect_success != stats['version']):
            raise AssertionError("version mismatch: expected %r, got %r"
                                 % (expect_success, stats['version']))
2924
2925
2926class ThreadedTests(unittest.TestCase):
2927
    def test_echo(self):
        """Basic test of an SSL client connecting to a server.

        The first subtest runs a real echo exchange with a TLS_CLIENT
        context on the client and a TLS_SERVER context on the server.  The
        remaining subtests pass mismatched contexts to server_params_test().
        Each expects an ssl.SSLError whose message says a client socket
        cannot be created from a PROTOCOL_TLS_SERVER context.
        """
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()

        # Correctly paired contexts: the echo round trip must succeed.
        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
            server_params_test(client_context=client_context,
                               server_context=server_context,
                               chatty=True, connectionchatty=True,
                               sni_name=hostname)

        # NOTE(review): presumably disabled so that client_context can be
        # used on the server side below without tripping the
        # "check_hostname requires server_hostname" ValueError seen in
        # test_check_hostname — confirm.
        client_context.check_hostname = False
        # Swapped contexts: wrapping the client socket with the server
        # context is expected to fail.
        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
            with self.assertRaises(ssl.SSLError) as e:
                server_params_test(client_context=server_context,
                                   server_context=client_context,
                                   chatty=True, connectionchatty=True,
                                   sni_name=hostname)
            self.assertIn(
                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
                str(e.exception)
            )

        # Server context on both sides: the client-side wrap fails.
        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
            with self.assertRaises(ssl.SSLError) as e:
                server_params_test(client_context=server_context,
                                   server_context=server_context,
                                   chatty=True, connectionchatty=True)
            self.assertIn(
                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
                str(e.exception)
            )

        # NOTE(review): despite the TLS_CLIENT/TLS_CLIENT subTest labels,
        # this passes the same contexts as the second subtest (server_context
        # as client, client_context as server).  It therefore re-checks that
        # failure rather than a TLS_CLIENT-on-both-sides pairing — confirm intent.
        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
            with self.assertRaises(ssl.SSLError) as e:
                server_params_test(client_context=server_context,
                                   server_context=client_context,
                                   chatty=True, connectionchatty=True)
            self.assertIn(
                'Cannot create a client socket with a PROTOCOL_TLS_SERVER context',
                str(e.exception))
2971
2972    def test_getpeercert(self):
2973        if support.verbose:
2974            sys.stdout.write("\n")
2975
2976        client_context, server_context, hostname = testing_context()
2977        server = ThreadedEchoServer(context=server_context, chatty=False)
2978        with server:
2979            with client_context.wrap_socket(socket.socket(),
2980                                            do_handshake_on_connect=False,
2981                                            server_hostname=hostname) as s:
2982                s.connect((HOST, server.port))
2983                # getpeercert() raise ValueError while the handshake isn't
2984                # done.
2985                with self.assertRaises(ValueError):
2986                    s.getpeercert()
2987                s.do_handshake()
2988                cert = s.getpeercert()
2989                self.assertTrue(cert, "Can't get peer certificate.")
2990                cipher = s.cipher()
2991                if support.verbose:
2992                    sys.stdout.write(pprint.pformat(cert) + '\n')
2993                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2994                if 'subject' not in cert:
2995                    self.fail("No subject field in certificate: %s." %
2996                              pprint.pformat(cert))
2997                if ((('organizationName', 'Python Software Foundation'),)
2998                    not in cert['subject']):
2999                    self.fail(
3000                        "Missing or invalid 'organizationName' field in certificate subject; "
3001                        "should be 'Python Software Foundation'.")
3002                self.assertIn('notBefore', cert)
3003                self.assertIn('notAfter', cert)
3004                before = ssl.cert_time_to_seconds(cert['notBefore'])
3005                after = ssl.cert_time_to_seconds(cert['notAfter'])
3006                self.assertLess(before, after)
3007
3008    def test_crl_check(self):
3009        if support.verbose:
3010            sys.stdout.write("\n")
3011
3012        client_context, server_context, hostname = testing_context()
3013
3014        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
3015        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
3016
3017        # VERIFY_DEFAULT should pass
3018        server = ThreadedEchoServer(context=server_context, chatty=True)
3019        with server:
3020            with client_context.wrap_socket(socket.socket(),
3021                                            server_hostname=hostname) as s:
3022                s.connect((HOST, server.port))
3023                cert = s.getpeercert()
3024                self.assertTrue(cert, "Can't get peer certificate.")
3025
3026        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
3027        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
3028
3029        server = ThreadedEchoServer(context=server_context, chatty=True)
3030        with server:
3031            with client_context.wrap_socket(socket.socket(),
3032                                            server_hostname=hostname) as s:
3033                with self.assertRaisesRegex(ssl.SSLError,
3034                                            "certificate verify failed"):
3035                    s.connect((HOST, server.port))
3036
3037        # now load a CRL file. The CRL file is signed by the CA.
3038        client_context.load_verify_locations(CRLFILE)
3039
3040        server = ThreadedEchoServer(context=server_context, chatty=True)
3041        with server:
3042            with client_context.wrap_socket(socket.socket(),
3043                                            server_hostname=hostname) as s:
3044                s.connect((HOST, server.port))
3045                cert = s.getpeercert()
3046                self.assertTrue(cert, "Can't get peer certificate.")
3047
3048    def test_check_hostname(self):
3049        if support.verbose:
3050            sys.stdout.write("\n")
3051
3052        client_context, server_context, hostname = testing_context()
3053
3054        # correct hostname should verify
3055        server = ThreadedEchoServer(context=server_context, chatty=True)
3056        with server:
3057            with client_context.wrap_socket(socket.socket(),
3058                                            server_hostname=hostname) as s:
3059                s.connect((HOST, server.port))
3060                cert = s.getpeercert()
3061                self.assertTrue(cert, "Can't get peer certificate.")
3062
3063        # incorrect hostname should raise an exception
3064        server = ThreadedEchoServer(context=server_context, chatty=True)
3065        with server:
3066            with client_context.wrap_socket(socket.socket(),
3067                                            server_hostname="invalid") as s:
3068                with self.assertRaisesRegex(
3069                        ssl.CertificateError,
3070                        "Hostname mismatch, certificate is not valid for 'invalid'."):
3071                    s.connect((HOST, server.port))
3072
3073        # missing server_hostname arg should cause an exception, too
3074        server = ThreadedEchoServer(context=server_context, chatty=True)
3075        with server:
3076            with socket.socket() as s:
3077                with self.assertRaisesRegex(ValueError,
3078                                            "check_hostname requires server_hostname"):
3079                    client_context.wrap_socket(s)
3080
3081    @unittest.skipUnless(
3082        ssl.HAS_NEVER_CHECK_COMMON_NAME, "test requires hostname_checks_common_name"
3083    )
3084    def test_hostname_checks_common_name(self):
3085        client_context, server_context, hostname = testing_context()
3086        assert client_context.hostname_checks_common_name
3087        client_context.hostname_checks_common_name = False
3088
3089        # default cert has a SAN
3090        server = ThreadedEchoServer(context=server_context, chatty=True)
3091        with server:
3092            with client_context.wrap_socket(socket.socket(),
3093                                            server_hostname=hostname) as s:
3094                s.connect((HOST, server.port))
3095
3096        client_context, server_context, hostname = testing_context(NOSANFILE)
3097        client_context.hostname_checks_common_name = False
3098        server = ThreadedEchoServer(context=server_context, chatty=True)
3099        with server:
3100            with client_context.wrap_socket(socket.socket(),
3101                                            server_hostname=hostname) as s:
3102                with self.assertRaises(ssl.SSLCertVerificationError):
3103                    s.connect((HOST, server.port))
3104
3105    def test_ecc_cert(self):
3106        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3107        client_context.load_verify_locations(SIGNING_CA)
3108        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3109        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3110
3111        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3112        # load ECC cert
3113        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3114
3115        # correct hostname should verify
3116        server = ThreadedEchoServer(context=server_context, chatty=True)
3117        with server:
3118            with client_context.wrap_socket(socket.socket(),
3119                                            server_hostname=hostname) as s:
3120                s.connect((HOST, server.port))
3121                cert = s.getpeercert()
3122                self.assertTrue(cert, "Can't get peer certificate.")
3123                cipher = s.cipher()[0].split('-')
3124                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3125
3126    def test_dual_rsa_ecc(self):
3127        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3128        client_context.load_verify_locations(SIGNING_CA)
3129        # TODO: fix TLSv1.3 once SSLContext can restrict signature
3130        #       algorithms.
3131        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3132        # only ECDSA certs
3133        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3134        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3135
3136        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3137        # load ECC and RSA key/cert pairs
3138        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3139        server_context.load_cert_chain(SIGNED_CERTFILE)
3140
3141        # correct hostname should verify
3142        server = ThreadedEchoServer(context=server_context, chatty=True)
3143        with server:
3144            with client_context.wrap_socket(socket.socket(),
3145                                            server_hostname=hostname) as s:
3146                s.connect((HOST, server.port))
3147                cert = s.getpeercert()
3148                self.assertTrue(cert, "Can't get peer certificate.")
3149                cipher = s.cipher()[0].split('-')
3150                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3151
3152    def test_check_hostname_idn(self):
3153        if support.verbose:
3154            sys.stdout.write("\n")
3155
3156        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3157        server_context.load_cert_chain(IDNSANSFILE)
3158
3159        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3160        context.verify_mode = ssl.CERT_REQUIRED
3161        context.check_hostname = True
3162        context.load_verify_locations(SIGNING_CA)
3163
3164        # correct hostname should verify, when specified in several
3165        # different ways
3166        idn_hostnames = [
3167            ('könig.idn.pythontest.net',
3168             'xn--knig-5qa.idn.pythontest.net'),
3169            ('xn--knig-5qa.idn.pythontest.net',
3170             'xn--knig-5qa.idn.pythontest.net'),
3171            (b'xn--knig-5qa.idn.pythontest.net',
3172             'xn--knig-5qa.idn.pythontest.net'),
3173
3174            ('königsgäßchen.idna2003.pythontest.net',
3175             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3176            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3177             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3178            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3179             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3180
3181            # ('königsgäßchen.idna2008.pythontest.net',
3182            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3183            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3184             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3185            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3186             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3187
3188        ]
3189        for server_hostname, expected_hostname in idn_hostnames:
3190            server = ThreadedEchoServer(context=server_context, chatty=True)
3191            with server:
3192                with context.wrap_socket(socket.socket(),
3193                                         server_hostname=server_hostname) as s:
3194                    self.assertEqual(s.server_hostname, expected_hostname)
3195                    s.connect((HOST, server.port))
3196                    cert = s.getpeercert()
3197                    self.assertEqual(s.server_hostname, expected_hostname)
3198                    self.assertTrue(cert, "Can't get peer certificate.")
3199
3200        # incorrect hostname should raise an exception
3201        server = ThreadedEchoServer(context=server_context, chatty=True)
3202        with server:
3203            with context.wrap_socket(socket.socket(),
3204                                     server_hostname="python.example.org") as s:
3205                with self.assertRaises(ssl.CertificateError):
3206                    s.connect((HOST, server.port))
3207
3208    def test_wrong_cert_tls12(self):
3209        """Connecting when the server rejects the client's certificate
3210
3211        Launch a server with CERT_REQUIRED, and check that trying to
3212        connect to it with a wrong client certificate fails.
3213        """
3214        client_context, server_context, hostname = testing_context()
3215        # load client cert that is not signed by trusted CA
3216        client_context.load_cert_chain(CERTFILE)
3217        # require TLS client authentication
3218        server_context.verify_mode = ssl.CERT_REQUIRED
3219        # TLS 1.3 has different handshake
3220        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3221
3222        server = ThreadedEchoServer(
3223            context=server_context, chatty=True, connectionchatty=True,
3224        )
3225
3226        with server, \
3227                client_context.wrap_socket(socket.socket(),
3228                                           server_hostname=hostname) as s:
3229            try:
3230                # Expect either an SSL error about the server rejecting
3231                # the connection, or a low-level connection reset (which
3232                # sometimes happens on Windows)
3233                s.connect((HOST, server.port))
3234            except ssl.SSLError as e:
3235                if support.verbose:
3236                    sys.stdout.write("\nSSLError is %r\n" % e)
3237            except OSError as e:
3238                if e.errno != errno.ECONNRESET:
3239                    raise
3240                if support.verbose:
3241                    sys.stdout.write("\nsocket.error is %r\n" % e)
3242            else:
3243                self.fail("Use of invalid cert should have failed!")
3244
3245    @requires_tls_version('TLSv1_3')
3246    def test_wrong_cert_tls13(self):
3247        client_context, server_context, hostname = testing_context()
3248        # load client cert that is not signed by trusted CA
3249        client_context.load_cert_chain(CERTFILE)
3250        server_context.verify_mode = ssl.CERT_REQUIRED
3251        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3252        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3253
3254        server = ThreadedEchoServer(
3255            context=server_context, chatty=True, connectionchatty=True,
3256        )
3257        with server, \
3258             client_context.wrap_socket(socket.socket(),
3259                                        server_hostname=hostname,
3260                                        suppress_ragged_eofs=False) as s:
3261            # TLS 1.3 perform client cert exchange after handshake
3262            s.connect((HOST, server.port))
3263            try:
3264                s.write(b'data')
3265                s.read(1000)
3266                s.write(b'should have failed already')
3267                s.read(1000)
3268            except ssl.SSLError as e:
3269                if support.verbose:
3270                    sys.stdout.write("\nSSLError is %r\n" % e)
3271            except OSError as e:
3272                if e.errno != errno.ECONNRESET:
3273                    raise
3274                if support.verbose:
3275                    sys.stdout.write("\nsocket.error is %r\n" % e)
3276            else:
3277                self.fail("Use of invalid cert should have failed!")
3278
3279    def test_rude_shutdown(self):
3280        """A brutal shutdown of an SSL server should raise an OSError
3281        in the client when attempting handshake.
3282        """
3283        listener_ready = threading.Event()
3284        listener_gone = threading.Event()
3285
3286        s = socket.socket()
3287        port = socket_helper.bind_port(s, HOST)
3288
3289        # `listener` runs in a thread.  It sits in an accept() until
3290        # the main thread connects.  Then it rudely closes the socket,
3291        # and sets Event `listener_gone` to let the main thread know
3292        # the socket is gone.
3293        def listener():
3294            s.listen()
3295            listener_ready.set()
3296            newsock, addr = s.accept()
3297            newsock.close()
3298            s.close()
3299            listener_gone.set()
3300
3301        def connector():
3302            listener_ready.wait()
3303            with socket.socket() as c:
3304                c.connect((HOST, port))
3305                listener_gone.wait()
3306                try:
3307                    ssl_sock = test_wrap_socket(c)
3308                except OSError:
3309                    pass
3310                else:
3311                    self.fail('connecting to closed SSL socket should have failed')
3312
3313        t = threading.Thread(target=listener)
3314        t.start()
3315        try:
3316            connector()
3317        finally:
3318            t.join()
3319
3320    def test_ssl_cert_verify_error(self):
3321        if support.verbose:
3322            sys.stdout.write("\n")
3323
3324        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3325        server_context.load_cert_chain(SIGNED_CERTFILE)
3326
3327        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3328
3329        server = ThreadedEchoServer(context=server_context, chatty=True)
3330        with server:
3331            with context.wrap_socket(socket.socket(),
3332                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3333                try:
3334                    s.connect((HOST, server.port))
3335                except ssl.SSLError as e:
3336                    msg = 'unable to get local issuer certificate'
3337                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3338                    self.assertEqual(e.verify_code, 20)
3339                    self.assertEqual(e.verify_message, msg)
3340                    self.assertIn(msg, repr(e))
3341                    self.assertIn('certificate verify failed', repr(e))
3342
3343    @requires_tls_version('SSLv2')
3344    def test_protocol_sslv2(self):
3345        """Connecting to an SSLv2 server with various client options"""
3346        if support.verbose:
3347            sys.stdout.write("\n")
3348        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3349        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3350        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3351        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3352        if has_tls_version('SSLv3'):
3353            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3354        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3355        # SSLv23 client with specific SSL options
3356        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3357                           client_options=ssl.OP_NO_SSLv3)
3358        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3359                           client_options=ssl.OP_NO_TLSv1)
3360
3361    def test_PROTOCOL_TLS(self):
3362        """Connecting to an SSLv23 server with various client options"""
3363        if support.verbose:
3364            sys.stdout.write("\n")
3365        if has_tls_version('SSLv2'):
3366            try:
3367                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3368            except OSError as x:
3369                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3370                if support.verbose:
3371                    sys.stdout.write(
3372                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3373                        % str(x))
3374        if has_tls_version('SSLv3'):
3375            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3376        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3377        if has_tls_version('TLSv1'):
3378            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3379
3380        if has_tls_version('SSLv3'):
3381            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3382        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3383        if has_tls_version('TLSv1'):
3384            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3385
3386        if has_tls_version('SSLv3'):
3387            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3388        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3389        if has_tls_version('TLSv1'):
3390            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3391
3392        # Server with specific SSL options
3393        if has_tls_version('SSLv3'):
3394            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3395                           server_options=ssl.OP_NO_SSLv3)
3396        # Will choose TLSv1
3397        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3398                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3399        if has_tls_version('TLSv1'):
3400            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3401                               server_options=ssl.OP_NO_TLSv1)
3402
3403    @requires_tls_version('SSLv3')
3404    def test_protocol_sslv3(self):
3405        """Connecting to an SSLv3 server with various client options"""
3406        if support.verbose:
3407            sys.stdout.write("\n")
3408        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3409        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3410        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3411        if has_tls_version('SSLv2'):
3412            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3413        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3414                           client_options=ssl.OP_NO_SSLv3)
3415        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3416
3417    @requires_tls_version('TLSv1')
3418    def test_protocol_tlsv1(self):
3419        """Connecting to a TLSv1 server with various client options"""
3420        if support.verbose:
3421            sys.stdout.write("\n")
3422        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3423        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3424        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3425        if has_tls_version('SSLv2'):
3426            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3427        if has_tls_version('SSLv3'):
3428            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3429        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3430                           client_options=ssl.OP_NO_TLSv1)
3431
3432    @requires_tls_version('TLSv1_1')
3433    def test_protocol_tlsv1_1(self):
3434        """Connecting to a TLSv1.1 server with various client options.
3435           Testing against older TLS versions."""
3436        if support.verbose:
3437            sys.stdout.write("\n")
3438        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3439        if has_tls_version('SSLv2'):
3440            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3441        if has_tls_version('SSLv3'):
3442            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3443        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3444                           client_options=ssl.OP_NO_TLSv1_1)
3445
3446        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3447        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3448        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3449
3450    @requires_tls_version('TLSv1_2')
3451    def test_protocol_tlsv1_2(self):
3452        """Connecting to a TLSv1.2 server with various client options.
3453           Testing against older TLS versions."""
3454        if support.verbose:
3455            sys.stdout.write("\n")
3456        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3457                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3458                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3459        if has_tls_version('SSLv2'):
3460            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3461        if has_tls_version('SSLv3'):
3462            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3463        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3464                           client_options=ssl.OP_NO_TLSv1_2)
3465
3466        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3467        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3468        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3469        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3470        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3471
3472    def test_starttls(self):
3473        """Switching from clear text to encrypted and back again."""
3474        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
3475
3476        server = ThreadedEchoServer(CERTFILE,
3477                                    starttls_server=True,
3478                                    chatty=True,
3479                                    connectionchatty=True)
3480        wrapped = False
3481        with server:
3482            s = socket.socket()
3483            s.setblocking(True)
3484            s.connect((HOST, server.port))
3485            if support.verbose:
3486                sys.stdout.write("\n")
3487            for indata in msgs:
3488                if support.verbose:
3489                    sys.stdout.write(
3490                        " client:  sending %r...\n" % indata)
3491                if wrapped:
3492                    conn.write(indata)
3493                    outdata = conn.read()
3494                else:
3495                    s.send(indata)
3496                    outdata = s.recv(1024)
3497                msg = outdata.strip().lower()
3498                if indata == b"STARTTLS" and msg.startswith(b"ok"):
3499                    # STARTTLS ok, switch to secure mode
3500                    if support.verbose:
3501                        sys.stdout.write(
3502                            " client:  read %r from server, starting TLS...\n"
3503                            % msg)
3504                    conn = test_wrap_socket(s)
3505                    wrapped = True
3506                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
3507                    # ENDTLS ok, switch back to clear text
3508                    if support.verbose:
3509                        sys.stdout.write(
3510                            " client:  read %r from server, ending TLS...\n"
3511                            % msg)
3512                    s = conn.unwrap()
3513                    wrapped = False
3514                else:
3515                    if support.verbose:
3516                        sys.stdout.write(
3517                            " client:  read %r from server\n" % msg)
3518            if support.verbose:
3519                sys.stdout.write(" client:  closing connection.\n")
3520            if wrapped:
3521                conn.write(b"over\n")
3522            else:
3523                s.send(b"over\n")
3524            if wrapped:
3525                conn.close()
3526            else:
3527                s.close()
3528
3529    def test_socketserver(self):
3530        """Using socketserver to create and manage SSL connections."""
3531        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3532        # try to connect
3533        if support.verbose:
3534            sys.stdout.write('\n')
3535        with open(CERTFILE, 'rb') as f:
3536            d1 = f.read()
3537        d2 = ''
3538        # now fetch the same data from the HTTPS server
3539        url = 'https://localhost:%d/%s' % (
3540            server.port, os.path.split(CERTFILE)[1])
3541        context = ssl.create_default_context(cafile=SIGNING_CA)
3542        f = urllib.request.urlopen(url, context=context)
3543        try:
3544            dlen = f.info().get("content-length")
3545            if dlen and (int(dlen) > 0):
3546                d2 = f.read(int(dlen))
3547                if support.verbose:
3548                    sys.stdout.write(
3549                        " client: read %d bytes from remote server '%s'\n"
3550                        % (len(d2), server))
3551        finally:
3552            f.close()
3553        self.assertEqual(d1, d2)
3554
3555    def test_asyncore_server(self):
3556        """Check the example asyncore integration."""
3557        if support.verbose:
3558            sys.stdout.write("\n")
3559
3560        indata = b"FOO\n"
3561        server = AsyncoreEchoServer(CERTFILE)
3562        with server:
3563            s = test_wrap_socket(socket.socket())
3564            s.connect(('127.0.0.1', server.port))
3565            if support.verbose:
3566                sys.stdout.write(
3567                    " client:  sending %r...\n" % indata)
3568            s.write(indata)
3569            outdata = s.read()
3570            if support.verbose:
3571                sys.stdout.write(" client:  read %r\n" % outdata)
3572            if outdata != indata.lower():
3573                self.fail(
3574                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3575                    % (outdata[:20], len(outdata),
3576                       indata[:20].lower(), len(indata)))
3577            s.write(b"over\n")
3578            if support.verbose:
3579                sys.stdout.write(" client:  closing connection.\n")
3580            s.close()
3581            if support.verbose:
3582                sys.stdout.write(" client:  connection closed.\n")
3583
3584    def test_recv_send(self):
3585        """Test recv(), send() and friends."""
3586        if support.verbose:
3587            sys.stdout.write("\n")
3588
3589        server = ThreadedEchoServer(CERTFILE,
3590                                    certreqs=ssl.CERT_NONE,
3591                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3592                                    cacerts=CERTFILE,
3593                                    chatty=True,
3594                                    connectionchatty=False)
3595        with server:
3596            s = test_wrap_socket(socket.socket(),
3597                                server_side=False,
3598                                certfile=CERTFILE,
3599                                ca_certs=CERTFILE,
3600                                cert_reqs=ssl.CERT_NONE)
3601            s.connect((HOST, server.port))
3602            # helper methods for standardising recv* method signatures
3603            def _recv_into():
3604                b = bytearray(b"\0"*100)
3605                count = s.recv_into(b)
3606                return b[:count]
3607
3608            def _recvfrom_into():
3609                b = bytearray(b"\0"*100)
3610                count, addr = s.recvfrom_into(b)
3611                return b[:count]
3612
3613            # (name, method, expect success?, *args, return value func)
3614            send_methods = [
3615                ('send', s.send, True, [], len),
3616                ('sendto', s.sendto, False, ["some.address"], len),
3617                ('sendall', s.sendall, True, [], lambda x: None),
3618            ]
3619            # (name, method, whether to expect success, *args)
3620            recv_methods = [
3621                ('recv', s.recv, True, []),
3622                ('recvfrom', s.recvfrom, False, ["some.address"]),
3623                ('recv_into', _recv_into, True, []),
3624                ('recvfrom_into', _recvfrom_into, False, []),
3625            ]
3626            data_prefix = "PREFIX_"
3627
3628            for (meth_name, send_meth, expect_success, args,
3629                    ret_val_meth) in send_methods:
3630                indata = (data_prefix + meth_name).encode('ascii')
3631                try:
3632                    ret = send_meth(indata, *args)
3633                    msg = "sending with {}".format(meth_name)
3634                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3635                    outdata = s.read()
3636                    if outdata != indata.lower():
3637                        self.fail(
3638                            "While sending with <<{name:s}>> bad data "
3639                            "<<{outdata:r}>> ({nout:d}) received; "
3640                            "expected <<{indata:r}>> ({nin:d})\n".format(
3641                                name=meth_name, outdata=outdata[:20],
3642                                nout=len(outdata),
3643                                indata=indata[:20], nin=len(indata)
3644                            )
3645                        )
3646                except ValueError as e:
3647                    if expect_success:
3648                        self.fail(
3649                            "Failed to send with method <<{name:s}>>; "
3650                            "expected to succeed.\n".format(name=meth_name)
3651                        )
3652                    if not str(e).startswith(meth_name):
3653                        self.fail(
3654                            "Method <<{name:s}>> failed with unexpected "
3655                            "exception message: {exp:s}\n".format(
3656                                name=meth_name, exp=e
3657                            )
3658                        )
3659
3660            for meth_name, recv_meth, expect_success, args in recv_methods:
3661                indata = (data_prefix + meth_name).encode('ascii')
3662                try:
3663                    s.send(indata)
3664                    outdata = recv_meth(*args)
3665                    if outdata != indata.lower():
3666                        self.fail(
3667                            "While receiving with <<{name:s}>> bad data "
3668                            "<<{outdata:r}>> ({nout:d}) received; "
3669                            "expected <<{indata:r}>> ({nin:d})\n".format(
3670                                name=meth_name, outdata=outdata[:20],
3671                                nout=len(outdata),
3672                                indata=indata[:20], nin=len(indata)
3673                            )
3674                        )
3675                except ValueError as e:
3676                    if expect_success:
3677                        self.fail(
3678                            "Failed to receive with method <<{name:s}>>; "
3679                            "expected to succeed.\n".format(name=meth_name)
3680                        )
3681                    if not str(e).startswith(meth_name):
3682                        self.fail(
3683                            "Method <<{name:s}>> failed with unexpected "
3684                            "exception message: {exp:s}\n".format(
3685                                name=meth_name, exp=e
3686                            )
3687                        )
3688                    # consume data
3689                    s.read()
3690
3691            # read(-1, buffer) is supported, even though read(-1) is not
3692            data = b"data"
3693            s.send(data)
3694            buffer = bytearray(len(data))
3695            self.assertEqual(s.read(-1, buffer), len(data))
3696            self.assertEqual(buffer, data)
3697
3698            # sendall accepts bytes-like objects
3699            if ctypes is not None:
3700                ubyte = ctypes.c_ubyte * len(data)
3701                byteslike = ubyte.from_buffer_copy(data)
3702                s.sendall(byteslike)
3703                self.assertEqual(s.read(), data)
3704
3705            # Make sure sendmsg et al are disallowed to avoid
3706            # inadvertent disclosure of data and/or corruption
3707            # of the encrypted data stream
3708            self.assertRaises(NotImplementedError, s.dup)
3709            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3710            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3711            self.assertRaises(NotImplementedError,
3712                              s.recvmsg_into, [bytearray(100)])
3713            s.write(b"over\n")
3714
3715            self.assertRaises(ValueError, s.recv, -1)
3716            self.assertRaises(ValueError, s.read, -1)
3717
3718            s.close()
3719
3720    def test_recv_zero(self):
3721        server = ThreadedEchoServer(CERTFILE)
3722        server.__enter__()
3723        self.addCleanup(server.__exit__, None, None)
3724        s = socket.create_connection((HOST, server.port))
3725        self.addCleanup(s.close)
3726        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3727        self.addCleanup(s.close)
3728
3729        # recv/read(0) should return no data
3730        s.send(b"data")
3731        self.assertEqual(s.recv(0), b"")
3732        self.assertEqual(s.read(0), b"")
3733        self.assertEqual(s.read(), b"data")
3734
3735        # Should not block if the other end sends no data
3736        s.setblocking(False)
3737        self.assertEqual(s.recv(0), b"")
3738        self.assertEqual(s.recv_into(bytearray()), 0)
3739
3740    def test_nonblocking_send(self):
3741        server = ThreadedEchoServer(CERTFILE,
3742                                    certreqs=ssl.CERT_NONE,
3743                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3744                                    cacerts=CERTFILE,
3745                                    chatty=True,
3746                                    connectionchatty=False)
3747        with server:
3748            s = test_wrap_socket(socket.socket(),
3749                                server_side=False,
3750                                certfile=CERTFILE,
3751                                ca_certs=CERTFILE,
3752                                cert_reqs=ssl.CERT_NONE)
3753            s.connect((HOST, server.port))
3754            s.setblocking(False)
3755
3756            # If we keep sending data, at some point the buffers
3757            # will be full and the call will block
3758            buf = bytearray(8192)
3759            def fill_buffer():
3760                while True:
3761                    s.send(buf)
3762            self.assertRaises((ssl.SSLWantWriteError,
3763                               ssl.SSLWantReadError), fill_buffer)
3764
3765            # Now read all the output and discard it
3766            s.setblocking(True)
3767            s.close()
3768
3769    def test_handshake_timeout(self):
3770        # Issue #5103: SSL handshake must respect the socket timeout
3771        server = socket.socket(socket.AF_INET)
3772        host = "127.0.0.1"
3773        port = socket_helper.bind_port(server)
3774        started = threading.Event()
3775        finish = False
3776
3777        def serve():
3778            server.listen()
3779            started.set()
3780            conns = []
3781            while not finish:
3782                r, w, e = select.select([server], [], [], 0.1)
3783                if server in r:
3784                    # Let the socket hang around rather than having
3785                    # it closed by garbage collection.
3786                    conns.append(server.accept()[0])
3787            for sock in conns:
3788                sock.close()
3789
3790        t = threading.Thread(target=serve)
3791        t.start()
3792        started.wait()
3793
3794        try:
3795            try:
3796                c = socket.socket(socket.AF_INET)
3797                c.settimeout(0.2)
3798                c.connect((host, port))
3799                # Will attempt handshake and time out
3800                self.assertRaisesRegex(TimeoutError, "timed out",
3801                                       test_wrap_socket, c)
3802            finally:
3803                c.close()
3804            try:
3805                c = socket.socket(socket.AF_INET)
3806                c = test_wrap_socket(c)
3807                c.settimeout(0.2)
3808                # Will attempt handshake and time out
3809                self.assertRaisesRegex(TimeoutError, "timed out",
3810                                       c.connect, (host, port))
3811            finally:
3812                c.close()
3813        finally:
3814            finish = True
3815            t.join()
3816            server.close()
3817
3818    def test_server_accept(self):
3819        # Issue #16357: accept() on a SSLSocket created through
3820        # SSLContext.wrap_socket().
3821        client_ctx, server_ctx, hostname = testing_context()
3822        server = socket.socket(socket.AF_INET)
3823        host = "127.0.0.1"
3824        port = socket_helper.bind_port(server)
3825        server = server_ctx.wrap_socket(server, server_side=True)
3826        self.assertTrue(server.server_side)
3827
3828        evt = threading.Event()
3829        remote = None
3830        peer = None
3831        def serve():
3832            nonlocal remote, peer
3833            server.listen()
3834            # Block on the accept and wait on the connection to close.
3835            evt.set()
3836            remote, peer = server.accept()
3837            remote.send(remote.recv(4))
3838
3839        t = threading.Thread(target=serve)
3840        t.start()
3841        # Client wait until server setup and perform a connect.
3842        evt.wait()
3843        client = client_ctx.wrap_socket(
3844            socket.socket(), server_hostname=hostname
3845        )
3846        client.connect((hostname, port))
3847        client.send(b'data')
3848        client.recv()
3849        client_addr = client.getsockname()
3850        client.close()
3851        t.join()
3852        remote.close()
3853        server.close()
3854        # Sanity checks.
3855        self.assertIsInstance(remote, ssl.SSLSocket)
3856        self.assertEqual(peer, client_addr)
3857
3858    def test_getpeercert_enotconn(self):
3859        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3860        context.check_hostname = False
3861        with context.wrap_socket(socket.socket()) as sock:
3862            with self.assertRaises(OSError) as cm:
3863                sock.getpeercert()
3864            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3865
3866    def test_do_handshake_enotconn(self):
3867        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3868        context.check_hostname = False
3869        with context.wrap_socket(socket.socket()) as sock:
3870            with self.assertRaises(OSError) as cm:
3871                sock.do_handshake()
3872            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3873
3874    def test_no_shared_ciphers(self):
3875        client_context, server_context, hostname = testing_context()
3876        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3877        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3878        # Force different suites on client and server
3879        client_context.set_ciphers("AES128")
3880        server_context.set_ciphers("AES256")
3881        with ThreadedEchoServer(context=server_context) as server:
3882            with client_context.wrap_socket(socket.socket(),
3883                                            server_hostname=hostname) as s:
3884                with self.assertRaises(OSError):
3885                    s.connect((HOST, server.port))
3886        self.assertIn("no shared cipher", server.conn_errors[0])
3887
3888    def test_version_basic(self):
3889        """
3890        Basic tests for SSLSocket.version().
3891        More tests are done in the test_protocol_*() methods.
3892        """
3893        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3894        context.check_hostname = False
3895        context.verify_mode = ssl.CERT_NONE
3896        with ThreadedEchoServer(CERTFILE,
3897                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3898                                chatty=False) as server:
3899            with context.wrap_socket(socket.socket()) as s:
3900                self.assertIs(s.version(), None)
3901                self.assertIs(s._sslobj, None)
3902                s.connect((HOST, server.port))
3903                self.assertEqual(s.version(), 'TLSv1.3')
3904            self.assertIs(s._sslobj, None)
3905            self.assertIs(s.version(), None)
3906
3907    @requires_tls_version('TLSv1_3')
3908    def test_tls1_3(self):
3909        client_context, server_context, hostname = testing_context()
3910        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3911        with ThreadedEchoServer(context=server_context) as server:
3912            with client_context.wrap_socket(socket.socket(),
3913                                            server_hostname=hostname) as s:
3914                s.connect((HOST, server.port))
3915                self.assertIn(s.cipher()[0], {
3916                    'TLS_AES_256_GCM_SHA384',
3917                    'TLS_CHACHA20_POLY1305_SHA256',
3918                    'TLS_AES_128_GCM_SHA256',
3919                })
3920                self.assertEqual(s.version(), 'TLSv1.3')
3921
3922    @requires_tls_version('TLSv1_2')
3923    @requires_tls_version('TLSv1')
3924    @ignore_deprecation
3925    def test_min_max_version_tlsv1_2(self):
3926        client_context, server_context, hostname = testing_context()
3927        # client TLSv1.0 to 1.2
3928        client_context.minimum_version = ssl.TLSVersion.TLSv1
3929        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3930        # server only TLSv1.2
3931        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3932        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3933
3934        with ThreadedEchoServer(context=server_context) as server:
3935            with client_context.wrap_socket(socket.socket(),
3936                                            server_hostname=hostname) as s:
3937                s.connect((HOST, server.port))
3938                self.assertEqual(s.version(), 'TLSv1.2')
3939
3940    @requires_tls_version('TLSv1_1')
3941    @ignore_deprecation
3942    def test_min_max_version_tlsv1_1(self):
3943        client_context, server_context, hostname = testing_context()
3944        # client 1.0 to 1.2, server 1.0 to 1.1
3945        client_context.minimum_version = ssl.TLSVersion.TLSv1
3946        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3947        server_context.minimum_version = ssl.TLSVersion.TLSv1
3948        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3949        seclevel_workaround(client_context, server_context)
3950
3951        with ThreadedEchoServer(context=server_context) as server:
3952            with client_context.wrap_socket(socket.socket(),
3953                                            server_hostname=hostname) as s:
3954                s.connect((HOST, server.port))
3955                self.assertEqual(s.version(), 'TLSv1.1')
3956
3957    @requires_tls_version('TLSv1_2')
3958    @requires_tls_version('TLSv1')
3959    @ignore_deprecation
3960    def test_min_max_version_mismatch(self):
3961        client_context, server_context, hostname = testing_context()
3962        # client 1.0, server 1.2 (mismatch)
3963        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3964        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3965        client_context.maximum_version = ssl.TLSVersion.TLSv1
3966        client_context.minimum_version = ssl.TLSVersion.TLSv1
3967        seclevel_workaround(client_context, server_context)
3968
3969        with ThreadedEchoServer(context=server_context) as server:
3970            with client_context.wrap_socket(socket.socket(),
3971                                            server_hostname=hostname) as s:
3972                with self.assertRaises(ssl.SSLError) as e:
3973                    s.connect((HOST, server.port))
3974                self.assertIn("alert", str(e.exception))
3975
3976    @requires_tls_version('SSLv3')
3977    def test_min_max_version_sslv3(self):
3978        client_context, server_context, hostname = testing_context()
3979        server_context.minimum_version = ssl.TLSVersion.SSLv3
3980        client_context.minimum_version = ssl.TLSVersion.SSLv3
3981        client_context.maximum_version = ssl.TLSVersion.SSLv3
3982        seclevel_workaround(client_context, server_context)
3983
3984        with ThreadedEchoServer(context=server_context) as server:
3985            with client_context.wrap_socket(socket.socket(),
3986                                            server_hostname=hostname) as s:
3987                s.connect((HOST, server.port))
3988                self.assertEqual(s.version(), 'SSLv3')
3989
3990    def test_default_ecdh_curve(self):
3991        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
3992        # should be enabled by default on SSL contexts.
3993        client_context, server_context, hostname = testing_context()
3994        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
3995        # cipher name.
3996        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3997        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
3998        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
3999        # our default cipher list should prefer ECDH-based ciphers
4000        # automatically.
4001        with ThreadedEchoServer(context=server_context) as server:
4002            with client_context.wrap_socket(socket.socket(),
4003                                            server_hostname=hostname) as s:
4004                s.connect((HOST, server.port))
4005                self.assertIn("ECDH", s.cipher()[0])
4006
4007    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
4008                         "'tls-unique' channel binding not available")
4009    def test_tls_unique_channel_binding(self):
4010        """Test tls-unique channel binding."""
4011        if support.verbose:
4012            sys.stdout.write("\n")
4013
4014        client_context, server_context, hostname = testing_context()
4015
4016        server = ThreadedEchoServer(context=server_context,
4017                                    chatty=True,
4018                                    connectionchatty=False)
4019
4020        with server:
4021            with client_context.wrap_socket(
4022                    socket.socket(),
4023                    server_hostname=hostname) as s:
4024                s.connect((HOST, server.port))
4025                # get the data
4026                cb_data = s.get_channel_binding("tls-unique")
4027                if support.verbose:
4028                    sys.stdout.write(
4029                        " got channel binding data: {0!r}\n".format(cb_data))
4030
4031                # check if it is sane
4032                self.assertIsNotNone(cb_data)
4033                if s.version() == 'TLSv1.3':
4034                    self.assertEqual(len(cb_data), 48)
4035                else:
4036                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4037
4038                # and compare with the peers version
4039                s.write(b"CB tls-unique\n")
4040                peer_data_repr = s.read().strip()
4041                self.assertEqual(peer_data_repr,
4042                                 repr(cb_data).encode("us-ascii"))
4043
4044            # now, again
4045            with client_context.wrap_socket(
4046                    socket.socket(),
4047                    server_hostname=hostname) as s:
4048                s.connect((HOST, server.port))
4049                new_cb_data = s.get_channel_binding("tls-unique")
4050                if support.verbose:
4051                    sys.stdout.write(
4052                        "got another channel binding data: {0!r}\n".format(
4053                            new_cb_data)
4054                    )
4055                # is it really unique
4056                self.assertNotEqual(cb_data, new_cb_data)
4057                self.assertIsNotNone(cb_data)
4058                if s.version() == 'TLSv1.3':
4059                    self.assertEqual(len(cb_data), 48)
4060                else:
4061                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4062                s.write(b"CB tls-unique\n")
4063                peer_data_repr = s.read().strip()
4064                self.assertEqual(peer_data_repr,
4065                                 repr(new_cb_data).encode("us-ascii"))
4066
4067    def test_compression(self):
4068        client_context, server_context, hostname = testing_context()
4069        stats = server_params_test(client_context, server_context,
4070                                   chatty=True, connectionchatty=True,
4071                                   sni_name=hostname)
4072        if support.verbose:
4073            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
4074        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
4075
4076    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
4077                         "ssl.OP_NO_COMPRESSION needed for this test")
4078    def test_compression_disabled(self):
4079        client_context, server_context, hostname = testing_context()
4080        client_context.options |= ssl.OP_NO_COMPRESSION
4081        server_context.options |= ssl.OP_NO_COMPRESSION
4082        stats = server_params_test(client_context, server_context,
4083                                   chatty=True, connectionchatty=True,
4084                                   sni_name=hostname)
4085        self.assertIs(stats['compression'], None)
4086
4087    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4088    def test_dh_params(self):
4089        # Check we can get a connection with ephemeral Diffie-Hellman
4090        client_context, server_context, hostname = testing_context()
4091        # test scenario needs TLS <= 1.2
4092        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4093        server_context.load_dh_params(DHFILE)
4094        server_context.set_ciphers("kEDH")
4095        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
4096        stats = server_params_test(client_context, server_context,
4097                                   chatty=True, connectionchatty=True,
4098                                   sni_name=hostname)
4099        cipher = stats["cipher"][0]
4100        parts = cipher.split("-")
4101        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
4102            self.fail("Non-DH cipher: " + cipher[0])
4103
4104    def test_ecdh_curve(self):
4105        # server secp384r1, client auto
4106        client_context, server_context, hostname = testing_context()
4107
4108        server_context.set_ecdh_curve("secp384r1")
4109        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4110        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4111        stats = server_params_test(client_context, server_context,
4112                                   chatty=True, connectionchatty=True,
4113                                   sni_name=hostname)
4114
4115        # server auto, client secp384r1
4116        client_context, server_context, hostname = testing_context()
4117        client_context.set_ecdh_curve("secp384r1")
4118        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4119        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4120        stats = server_params_test(client_context, server_context,
4121                                   chatty=True, connectionchatty=True,
4122                                   sni_name=hostname)
4123
4124        # server / client curve mismatch
4125        client_context, server_context, hostname = testing_context()
4126        client_context.set_ecdh_curve("prime256v1")
4127        server_context.set_ecdh_curve("secp384r1")
4128        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4129        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
4130        with self.assertRaises(ssl.SSLError):
4131            server_params_test(client_context, server_context,
4132                               chatty=True, connectionchatty=True,
4133                               sni_name=hostname)
4134
4135    def test_selected_alpn_protocol(self):
4136        # selected_alpn_protocol() is None unless ALPN is used.
4137        client_context, server_context, hostname = testing_context()
4138        stats = server_params_test(client_context, server_context,
4139                                   chatty=True, connectionchatty=True,
4140                                   sni_name=hostname)
4141        self.assertIs(stats['client_alpn_protocol'], None)
4142
4143    def test_selected_alpn_protocol_if_server_uses_alpn(self):
4144        # selected_alpn_protocol() is None unless ALPN is used by the client.
4145        client_context, server_context, hostname = testing_context()
4146        server_context.set_alpn_protocols(['foo', 'bar'])
4147        stats = server_params_test(client_context, server_context,
4148                                   chatty=True, connectionchatty=True,
4149                                   sni_name=hostname)
4150        self.assertIs(stats['client_alpn_protocol'], None)
4151
4152    def test_alpn_protocols(self):
4153        server_protocols = ['foo', 'bar', 'milkshake']
4154        protocol_tests = [
4155            (['foo', 'bar'], 'foo'),
4156            (['bar', 'foo'], 'foo'),
4157            (['milkshake'], 'milkshake'),
4158            (['http/3.0', 'http/4.0'], None)
4159        ]
4160        for client_protocols, expected in protocol_tests:
4161            client_context, server_context, hostname = testing_context()
4162            server_context.set_alpn_protocols(server_protocols)
4163            client_context.set_alpn_protocols(client_protocols)
4164
4165            try:
4166                stats = server_params_test(client_context,
4167                                           server_context,
4168                                           chatty=True,
4169                                           connectionchatty=True,
4170                                           sni_name=hostname)
4171            except ssl.SSLError as e:
4172                stats = e
4173
4174            msg = "failed trying %s (s) and %s (c).\n" \
4175                "was expecting %s, but got %%s from the %%s" \
4176                    % (str(server_protocols), str(client_protocols),
4177                        str(expected))
4178            client_result = stats['client_alpn_protocol']
4179            self.assertEqual(client_result, expected,
4180                             msg % (client_result, "client"))
4181            server_result = stats['server_alpn_protocols'][-1] \
4182                if len(stats['server_alpn_protocols']) else 'nothing'
4183            self.assertEqual(server_result, expected,
4184                             msg % (server_result, "server"))
4185
4186    def test_npn_protocols(self):
4187        assert not ssl.HAS_NPN
4188
4189    def sni_contexts(self):
4190        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4191        server_context.load_cert_chain(SIGNED_CERTFILE)
4192        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4193        other_context.load_cert_chain(SIGNED_CERTFILE2)
4194        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4195        client_context.load_verify_locations(SIGNING_CA)
4196        return server_context, other_context, client_context
4197
4198    def check_common_name(self, stats, name):
4199        cert = stats['peercert']
4200        self.assertIn((('commonName', name),), cert['subject'])
4201
4202    def test_sni_callback(self):
4203        calls = []
4204        server_context, other_context, client_context = self.sni_contexts()
4205
4206        client_context.check_hostname = False
4207
4208        def servername_cb(ssl_sock, server_name, initial_context):
4209            calls.append((server_name, initial_context))
4210            if server_name is not None:
4211                ssl_sock.context = other_context
4212        server_context.set_servername_callback(servername_cb)
4213
4214        stats = server_params_test(client_context, server_context,
4215                                   chatty=True,
4216                                   sni_name='supermessage')
4217        # The hostname was fetched properly, and the certificate was
4218        # changed for the connection.
4219        self.assertEqual(calls, [("supermessage", server_context)])
4220        # CERTFILE4 was selected
4221        self.check_common_name(stats, 'fakehostname')
4222
4223        calls = []
4224        # The callback is called with server_name=None
4225        stats = server_params_test(client_context, server_context,
4226                                   chatty=True,
4227                                   sni_name=None)
4228        self.assertEqual(calls, [(None, server_context)])
4229        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4230
4231        # Check disabling the callback
4232        calls = []
4233        server_context.set_servername_callback(None)
4234
4235        stats = server_params_test(client_context, server_context,
4236                                   chatty=True,
4237                                   sni_name='notfunny')
4238        # Certificate didn't change
4239        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4240        self.assertEqual(calls, [])
4241
4242    def test_sni_callback_alert(self):
4243        # Returning a TLS alert is reflected to the connecting client
4244        server_context, other_context, client_context = self.sni_contexts()
4245
4246        def cb_returning_alert(ssl_sock, server_name, initial_context):
4247            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4248        server_context.set_servername_callback(cb_returning_alert)
4249        with self.assertRaises(ssl.SSLError) as cm:
4250            stats = server_params_test(client_context, server_context,
4251                                       chatty=False,
4252                                       sni_name='supermessage')
4253        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4254
4255    def test_sni_callback_raising(self):
4256        # Raising fails the connection with a TLS handshake failure alert.
4257        server_context, other_context, client_context = self.sni_contexts()
4258
4259        def cb_raising(ssl_sock, server_name, initial_context):
4260            1/0
4261        server_context.set_servername_callback(cb_raising)
4262
4263        with support.catch_unraisable_exception() as catch:
4264            with self.assertRaises(ssl.SSLError) as cm:
4265                stats = server_params_test(client_context, server_context,
4266                                           chatty=False,
4267                                           sni_name='supermessage')
4268
4269            self.assertEqual(cm.exception.reason,
4270                             'SSLV3_ALERT_HANDSHAKE_FAILURE')
4271            self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
4272
4273    def test_sni_callback_wrong_return_type(self):
4274        # Returning the wrong return type terminates the TLS connection
4275        # with an internal error alert.
4276        server_context, other_context, client_context = self.sni_contexts()
4277
4278        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4279            return "foo"
4280        server_context.set_servername_callback(cb_wrong_return_type)
4281
4282        with support.catch_unraisable_exception() as catch:
4283            with self.assertRaises(ssl.SSLError) as cm:
4284                stats = server_params_test(client_context, server_context,
4285                                           chatty=False,
4286                                           sni_name='supermessage')
4287
4288
4289            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4290            self.assertEqual(catch.unraisable.exc_type, TypeError)
4291
4292    def test_shared_ciphers(self):
4293        client_context, server_context, hostname = testing_context()
4294        client_context.set_ciphers("AES128:AES256")
4295        server_context.set_ciphers("AES256")
4296        expected_algs = [
4297            "AES256", "AES-256",
4298            # TLS 1.3 ciphers are always enabled
4299            "TLS_CHACHA20", "TLS_AES",
4300        ]
4301
4302        stats = server_params_test(client_context, server_context,
4303                                   sni_name=hostname)
4304        ciphers = stats['server_shared_ciphers'][0]
4305        self.assertGreater(len(ciphers), 0)
4306        for name, tls_version, bits in ciphers:
4307            if not any(alg in name for alg in expected_algs):
4308                self.fail(name)
4309
4310    def test_read_write_after_close_raises_valuerror(self):
4311        client_context, server_context, hostname = testing_context()
4312        server = ThreadedEchoServer(context=server_context, chatty=False)
4313
4314        with server:
4315            s = client_context.wrap_socket(socket.socket(),
4316                                           server_hostname=hostname)
4317            s.connect((HOST, server.port))
4318            s.close()
4319
4320            self.assertRaises(ValueError, s.read, 1024)
4321            self.assertRaises(ValueError, s.write, b'hello')
4322
4323    def test_sendfile(self):
4324        TEST_DATA = b"x" * 512
4325        with open(os_helper.TESTFN, 'wb') as f:
4326            f.write(TEST_DATA)
4327        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4328        client_context, server_context, hostname = testing_context()
4329        server = ThreadedEchoServer(context=server_context, chatty=False)
4330        with server:
4331            with client_context.wrap_socket(socket.socket(),
4332                                            server_hostname=hostname) as s:
4333                s.connect((HOST, server.port))
4334                with open(os_helper.TESTFN, 'rb') as file:
4335                    s.sendfile(file)
4336                    self.assertEqual(s.recv(1024), TEST_DATA)
4337
4338    def test_session(self):
4339        client_context, server_context, hostname = testing_context()
4340        # TODO: sessions aren't compatible with TLSv1.3 yet
4341        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4342
4343        # first connection without session
4344        stats = server_params_test(client_context, server_context,
4345                                   sni_name=hostname)
4346        session = stats['session']
4347        self.assertTrue(session.id)
4348        self.assertGreater(session.time, 0)
4349        self.assertGreater(session.timeout, 0)
4350        self.assertTrue(session.has_ticket)
4351        self.assertGreater(session.ticket_lifetime_hint, 0)
4352        self.assertFalse(stats['session_reused'])
4353        sess_stat = server_context.session_stats()
4354        self.assertEqual(sess_stat['accept'], 1)
4355        self.assertEqual(sess_stat['hits'], 0)
4356
4357        # reuse session
4358        stats = server_params_test(client_context, server_context,
4359                                   session=session, sni_name=hostname)
4360        sess_stat = server_context.session_stats()
4361        self.assertEqual(sess_stat['accept'], 2)
4362        self.assertEqual(sess_stat['hits'], 1)
4363        self.assertTrue(stats['session_reused'])
4364        session2 = stats['session']
4365        self.assertEqual(session2.id, session.id)
4366        self.assertEqual(session2, session)
4367        self.assertIsNot(session2, session)
4368        self.assertGreaterEqual(session2.time, session.time)
4369        self.assertGreaterEqual(session2.timeout, session.timeout)
4370
4371        # another one without session
4372        stats = server_params_test(client_context, server_context,
4373                                   sni_name=hostname)
4374        self.assertFalse(stats['session_reused'])
4375        session3 = stats['session']
4376        self.assertNotEqual(session3.id, session.id)
4377        self.assertNotEqual(session3, session)
4378        sess_stat = server_context.session_stats()
4379        self.assertEqual(sess_stat['accept'], 3)
4380        self.assertEqual(sess_stat['hits'], 1)
4381
4382        # reuse session again
4383        stats = server_params_test(client_context, server_context,
4384                                   session=session, sni_name=hostname)
4385        self.assertTrue(stats['session_reused'])
4386        session4 = stats['session']
4387        self.assertEqual(session4.id, session.id)
4388        self.assertEqual(session4, session)
4389        self.assertGreaterEqual(session4.time, session.time)
4390        self.assertGreaterEqual(session4.timeout, session.timeout)
4391        sess_stat = server_context.session_stats()
4392        self.assertEqual(sess_stat['accept'], 4)
4393        self.assertEqual(sess_stat['hits'], 2)
4394
4395    def test_session_handling(self):
4396        client_context, server_context, hostname = testing_context()
4397        client_context2, _, _ = testing_context()
4398
4399        # TODO: session reuse does not work with TLSv1.3
4400        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4401        client_context2.maximum_version = ssl.TLSVersion.TLSv1_2
4402
4403        server = ThreadedEchoServer(context=server_context, chatty=False)
4404        with server:
4405            with client_context.wrap_socket(socket.socket(),
4406                                            server_hostname=hostname) as s:
4407                # session is None before handshake
4408                self.assertEqual(s.session, None)
4409                self.assertEqual(s.session_reused, None)
4410                s.connect((HOST, server.port))
4411                session = s.session
4412                self.assertTrue(session)
4413                with self.assertRaises(TypeError) as e:
4414                    s.session = object
4415                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4416
4417            with client_context.wrap_socket(socket.socket(),
4418                                            server_hostname=hostname) as s:
4419                s.connect((HOST, server.port))
4420                # cannot set session after handshake
4421                with self.assertRaises(ValueError) as e:
4422                    s.session = session
4423                self.assertEqual(str(e.exception),
4424                                 'Cannot set session after handshake.')
4425
4426            with client_context.wrap_socket(socket.socket(),
4427                                            server_hostname=hostname) as s:
4428                # can set session before handshake and before the
4429                # connection was established
4430                s.session = session
4431                s.connect((HOST, server.port))
4432                self.assertEqual(s.session.id, session.id)
4433                self.assertEqual(s.session, session)
4434                self.assertEqual(s.session_reused, True)
4435
4436            with client_context2.wrap_socket(socket.socket(),
4437                                             server_hostname=hostname) as s:
4438                # cannot re-use session with a different SSLContext
4439                with self.assertRaises(ValueError) as e:
4440                    s.session = session
4441                    s.connect((HOST, server.port))
4442                self.assertEqual(str(e.exception),
4443                                 'Session refers to a different SSLContext.')
4444
4445
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    """Tests for TLS 1.3 post-handshake client authentication (PHA).

    Most tests connect to a ThreadedEchoServer and send it text commands
    (b'HASCERT', b'PHA', b'GETCERT', ...). They then check the replies to
    observe whether, and when, the server obtained a client certificate.
    """

    def test_pha_setter(self):
        # post_handshake_auth and verify_mode can be toggled independently:
        # changing one never resets the other, on both server and client
        # contexts.
        protocols = [
            ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            # subTest reports which protocol failed.
            with self.subTest(protocol=protocol):
                ctx = ssl.SSLContext(protocol)
                self.assertEqual(ctx.post_handshake_auth, False)

                ctx.post_handshake_auth = True
                self.assertEqual(ctx.post_handshake_auth, True)

                ctx.verify_mode = ssl.CERT_REQUIRED
                self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
                self.assertEqual(ctx.post_handshake_auth, True)

                ctx.post_handshake_auth = False
                self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
                self.assertEqual(ctx.post_handshake_auth, False)

                ctx.verify_mode = ssl.CERT_OPTIONAL
                ctx.post_handshake_auth = True
                self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
                self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        # With PHA and CERT_REQUIRED, the server has no client certificate
        # after the initial handshake, only after a PHA request.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        # With CERT_REQUIRED, a PHA request to a client without a
        # certificate makes the server abort the connection with an error.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        # Print TLS alert records in verbose mode to help diagnose failures.
        def msg_cb(conn, direction, version, content_type, msg_type, data):
            if support.verbose and content_type == _TLSContentType.ALERT:
                info = (conn, direction, version, content_type, msg_type, data)
                sys.stdout.write(f"TLS: {info!r}\n")

        server_context._msg_callback = msg_cb
        client_context._msg_callback = msg_cb

        server = ThreadedEchoServer(context=server_context, chatty=True)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname,
                                            suppress_ragged_eofs=False) as s:
                s.connect((HOST, server.port))
                s.write(b'PHA')
                # test sometimes fails with EOF error. Test passes as long as
                # server aborts connection with an error.
                with self.assertRaisesRegex(
                    ssl.SSLError,
                    '(certificate required|EOF occurred)'
                ):
                    # receive CertificateRequest
                    data = s.recv(1024)
                    self.assertEqual(data, b'OK\n')

                    # send empty Certificate + Finish
                    s.write(b'HASCERT')

                    # receive alert
                    s.recv(1024)

    def test_pha_optional(self):
        # With CERT_OPTIONAL, a client that has a certificate provides it
        # when the server issues a PHA request.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        # check CERT_OPTIONAL with a client that has a certificate
        server_context.verify_mode = ssl.CERT_OPTIONAL
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        # With CERT_OPTIONAL, a PHA request to a client without a
        # certificate succeeds, and the server still has no certificate.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        # If the client did not enable PHA, the client side cannot call
        # verify_client_post_handshake(). A server-side PHA request reports
        # the missing extension.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        # With TLS 1.2 negotiated, a PHA request fails with WRONG_SSL_VERSION.
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))

    def test_bpo37428_pha_cert_none(self):
        # verify that post_handshake_auth does not implicitly enable cert
        # validation.
        hostname = SIGNED_CERTFILE_HOSTNAME
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)
        # no cert validation and CA on client side
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.load_verify_locations(SIGNING_CA)
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # server cert has not been validated
                self.assertEqual(s.getpeercert(), {})

    def test_internal_chain_client(self):
        # The server presents only its leaf certificate (server_chain=False):
        # - the client's unverified chain has one entry;
        # - the verified chain adds the CA.
        # Also checks equality, hashing, repr and encoding of certificates.
        client_context, server_context, hostname = testing_context(
            server_chain=False
        )
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(
                socket.socket(),
                server_hostname=hostname
            ) as s:
                s.connect((HOST, server.port))
                vc = s._sslobj.get_verified_chain()
                self.assertEqual(len(vc), 2)
                ee, ca = vc
                uvc = s._sslobj.get_unverified_chain()
                self.assertEqual(len(uvc), 1)

                self.assertEqual(ee, uvc[0])
                self.assertEqual(hash(ee), hash(uvc[0]))
                self.assertEqual(repr(ee), repr(uvc[0]))

                self.assertNotEqual(ee, ca)
                self.assertNotEqual(hash(ee), hash(ca))
                self.assertNotEqual(repr(ee), repr(ca))
                self.assertNotEqual(ee.get_info(), ca.get_info())
                self.assertIn("CN=localhost", repr(ee))
                self.assertIn("CN=our-ca-server", repr(ca))

                pem = ee.public_bytes(_ssl.ENCODING_PEM)
                der = ee.public_bytes(_ssl.ENCODING_DER)
                self.assertIsInstance(pem, str)
                self.assertIn("-----BEGIN CERTIFICATE-----", pem)
                self.assertIsInstance(der, bytes)
                self.assertEqual(
                    ssl.PEM_cert_to_DER_cert(pem), der
                )

    def test_internal_chain_server(self):
        # With a required client certificate over TLS 1.2, the server-side
        # verified and unverified chains each have two entries, as reported
        # by the echo server.
        client_context, server_context, hostname = testing_context()
        client_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server_context.maximum_version = ssl.TLSVersion.TLSv1_2

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(
                socket.socket(),
                server_hostname=hostname
            ) as s:
                s.connect((HOST, server.port))
                s.write(b'VERIFIEDCHAIN\n')
                res = s.recv(1024)
                self.assertEqual(res, b'\x02\n')
                s.write(b'UNVERIFIEDCHAIN\n')
                res = s.recv(1024)
                self.assertEqual(res, b'\x02\n')
4720
4721
# True when this ssl build exposes SSLContext.keylog_filename; tests that
# need key logging are decorated with requires_keylog and skipped otherwise.
HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
    HAS_KEYLOG,
    'test requires OpenSSL 1.1.1 with keylog callback',
)
4725
4726class TestSSLDebug(unittest.TestCase):
4727
4728    def keylog_lines(self, fname=os_helper.TESTFN):
4729        with open(fname) as f:
4730            return len(list(f))
4731
4732    @requires_keylog
4733    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4734    def test_keylog_defaults(self):
4735        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4736        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4737        self.assertEqual(ctx.keylog_filename, None)
4738
4739        self.assertFalse(os.path.isfile(os_helper.TESTFN))
4740        ctx.keylog_filename = os_helper.TESTFN
4741        self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4742        self.assertTrue(os.path.isfile(os_helper.TESTFN))
4743        self.assertEqual(self.keylog_lines(), 1)
4744
4745        ctx.keylog_filename = None
4746        self.assertEqual(ctx.keylog_filename, None)
4747
4748        with self.assertRaises((IsADirectoryError, PermissionError)):
4749            # Windows raises PermissionError
4750            ctx.keylog_filename = os.path.dirname(
4751                os.path.abspath(os_helper.TESTFN))
4752
4753        with self.assertRaises(TypeError):
4754            ctx.keylog_filename = 1
4755
4756    @requires_keylog
4757    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4758    def test_keylog_filename(self):
4759        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4760        client_context, server_context, hostname = testing_context()
4761
4762        client_context.keylog_filename = os_helper.TESTFN
4763        server = ThreadedEchoServer(context=server_context, chatty=False)
4764        with server:
4765            with client_context.wrap_socket(socket.socket(),
4766                                            server_hostname=hostname) as s:
4767                s.connect((HOST, server.port))
4768        # header, 5 lines for TLS 1.3
4769        self.assertEqual(self.keylog_lines(), 6)
4770
4771        client_context.keylog_filename = None
4772        server_context.keylog_filename = os_helper.TESTFN
4773        server = ThreadedEchoServer(context=server_context, chatty=False)
4774        with server:
4775            with client_context.wrap_socket(socket.socket(),
4776                                            server_hostname=hostname) as s:
4777                s.connect((HOST, server.port))
4778        self.assertGreaterEqual(self.keylog_lines(), 11)
4779
4780        client_context.keylog_filename = os_helper.TESTFN
4781        server_context.keylog_filename = os_helper.TESTFN
4782        server = ThreadedEchoServer(context=server_context, chatty=False)
4783        with server:
4784            with client_context.wrap_socket(socket.socket(),
4785                                            server_hostname=hostname) as s:
4786                s.connect((HOST, server.port))
4787        self.assertGreaterEqual(self.keylog_lines(), 21)
4788
4789        client_context.keylog_filename = None
4790        server_context.keylog_filename = None
4791
4792    @requires_keylog
4793    @unittest.skipIf(sys.flags.ignore_environment,
4794                     "test is not compatible with ignore_environment")
4795    @unittest.skipIf(Py_DEBUG_WIN32, "Avoid mixing debug/release CRT on Windows")
4796    def test_keylog_env(self):
4797        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
4798        with unittest.mock.patch.dict(os.environ):
4799            os.environ['SSLKEYLOGFILE'] = os_helper.TESTFN
4800            self.assertEqual(os.environ['SSLKEYLOGFILE'], os_helper.TESTFN)
4801
4802            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4803            self.assertEqual(ctx.keylog_filename, None)
4804
4805            ctx = ssl.create_default_context()
4806            self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4807
4808            ctx = ssl._create_stdlib_context()
4809            self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
4810
4811    def test_msg_callback(self):
4812        client_context, server_context, hostname = testing_context()
4813
4814        def msg_cb(conn, direction, version, content_type, msg_type, data):
4815            pass
4816
4817        self.assertIs(client_context._msg_callback, None)
4818        client_context._msg_callback = msg_cb
4819        self.assertIs(client_context._msg_callback, msg_cb)
4820        with self.assertRaises(TypeError):
4821            client_context._msg_callback = object()
4822
4823    def test_msg_callback_tls12(self):
4824        client_context, server_context, hostname = testing_context()
4825        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
4826
4827        msg = []
4828
4829        def msg_cb(conn, direction, version, content_type, msg_type, data):
4830            self.assertIsInstance(conn, ssl.SSLSocket)
4831            self.assertIsInstance(data, bytes)
4832            self.assertIn(direction, {'read', 'write'})
4833            msg.append((direction, version, content_type, msg_type))
4834
4835        client_context._msg_callback = msg_cb
4836
4837        server = ThreadedEchoServer(context=server_context, chatty=False)
4838        with server:
4839            with client_context.wrap_socket(socket.socket(),
4840                                            server_hostname=hostname) as s:
4841                s.connect((HOST, server.port))
4842
4843        self.assertIn(
4844            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
4845             _TLSMessageType.SERVER_KEY_EXCHANGE),
4846            msg
4847        )
4848        self.assertIn(
4849            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
4850             _TLSMessageType.CHANGE_CIPHER_SPEC),
4851            msg
4852        )
4853
4854    def test_msg_callback_deadlock_bpo43577(self):
4855        client_context, server_context, hostname = testing_context()
4856        server_context2 = testing_context()[1]
4857
4858        def msg_cb(conn, direction, version, content_type, msg_type, data):
4859            pass
4860
4861        def sni_cb(sock, servername, ctx):
4862            sock.context = server_context2
4863
4864        server_context._msg_callback = msg_cb
4865        server_context.sni_callback = sni_cb
4866
4867        server = ThreadedEchoServer(context=server_context, chatty=False)
4868        with server:
4869            with client_context.wrap_socket(socket.socket(),
4870                                            server_hostname=hostname) as s:
4871                s.connect((HOST, server.port))
4872            with client_context.wrap_socket(socket.socket(),
4873                                            server_hostname=hostname) as s:
4874                s.connect((HOST, server.port))
4875
4876
def setUpModule():
    """Module fixture: report the SSL environment and validate test data.

    In verbose mode, print the OpenSSL version, the detected platform and a
    few ``ssl`` flags. Then raise ``support.TestFailed`` for the first
    required certificate/key file that is missing. Finally, snapshot the
    running threads and register a module cleanup that checks them again.
    """
    if support.verbose:
        # Prefer a Mac/Windows specific description. Fall back to the generic
        # platform string only when neither probe reports anything.
        probes = (
            ('Mac', platform.mac_ver),
            ('Windows', platform.win32_ver),
        )
        for label, probe in probes:
            info = probe()
            if info and info[0]:
                plat = '%s %r' % (label, info)
                break
        else:
            plat = repr(platform.platform())

        print("test_ssl: testing with %r %r" %
              (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        print("          OP_ALL = 0x%8x" % ssl.OP_ALL)
        # OP_NO_TLSv1_1 may be absent depending on the OpenSSL build.
        no_tls11 = getattr(ssl, 'OP_NO_TLSv1_1', None)
        if no_tls11 is not None:
            print("          OP_NO_TLSv1_1 = 0x%8x" % no_tls11)

    required_files = (
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT,
    )
    for path in required_files:
        if not os.path.exists(path):
            raise support.TestFailed("Can't read certificate file %r" % path)

    saved_threads = threading_helper.threading_setup()
    unittest.addModuleCleanup(threading_helper.threading_cleanup,
                              *saved_threads)
4910
4911
# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()
4914