• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test the support for SSL and sockets
2
3import sys
4import unittest
5import unittest.mock
6from test import support
7import socket
8import select
9import time
10import datetime
11import gc
12import os
13import errno
14import pprint
15import urllib.request
16import threading
17import traceback
18import asyncore
19import weakref
20import platform
21import sysconfig
22import functools
23try:
24    import ctypes
25except ImportError:
26    ctypes = None
27
28ssl = support.import_module("ssl")
29
30from ssl import TLSVersion, _TLSContentType, _TLSMessageType
31
# Sorted keys of ssl._PROTOCOL_NAMES (presumably every PROTOCOL_* constant
# this ssl build knows about — TODO confirm).
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
# LibreSSL identifies itself in the version string; the OpenSSL version
# flags below are forced to False when running against LibreSSL.
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
# Build-time cipher configuration; sysconfig.get_config_var() returns None
# when the variable is not defined for this interpreter build.
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')
38
# Legacy PROTOCOL_* constants mapped to an ssl.TLSVersion member. A pair is
# left out when either the protocol constant or the TLSVersion member does
# not exist in this ssl build.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        pass
    else:
        PROTOCOL_TO_TLS_VERSION[proto] = ver
51
def data_file(*name):
    """Return the path built from *name* components, rooted at the
    directory that contains this module."""
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
54
# The custom key and certificate files used in test_ssl are generated
# using Lib/test/make_ssl_certs.py.
# Other certificates are simply fetched from the Internet servers they
# are meant to authenticate.

# Every path below is built with data_file(), i.e. it is relative to the
# directory containing this module. The BYTES_* names hold the same paths
# encoded with os.fsencode() (presumably to exercise bytes path arguments).
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
# Password-protected variants; KEY_PASSWORD is presumably their passphrase
# (TODO confirm against make_ssl_certs.py).
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
KEY_PASSWORD = "somepass"
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected result of ssl._ssl._test_decode_cert(CERTFILE); compared in
# BasicSocketTests.test_parse_cert.
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests)
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected result of ssl._ssl._test_decode_cert(SIGNED_CERTFILE); compared
# in BasicSocketTests.test_parse_cert.
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# Each *_HOSTNAME is the hostname paired with the certificate file above it
# (testing_context() returns it for SIGNED_CERTFILE and SIGNED_CERTFILE2).
SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
IDNSANSFILE = data_file("idnsans.pem")

REMOTE_HOST = "self-signed.pythontest.net"

# Inputs for parser regression and error-path tests. NONEXISTINGCERT must
# not exist on disk: test_errors_sslwrap expects ENOENT for it.
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL
# (a missing option falls back to 0, i.e. no option bits set)
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
145
146
def has_tls_protocol(protocol):
    """Return True if *protocol* is both available and enabled.

    :param protocol: an ssl._SSLMethod member, or the name of one as a
        string (which must start with 'PROTOCOL_'); unknown names give False
    :return: bool
    """
    if isinstance(protocol, str):
        assert protocol.startswith('PROTOCOL_')
        resolved = getattr(ssl, protocol, None)
        if resolved is None:
            return False
        protocol = resolved

    # The auto-negotiating methods never depend on a single TLS version.
    auto_negotiating = {
        ssl.PROTOCOL_TLS,
        ssl.PROTOCOL_TLS_SERVER,
        ssl.PROTOCOL_TLS_CLIENT,
    }
    if protocol in auto_negotiating:
        return True

    # e.g. PROTOCOL_TLSv1_2 -> "TLSv1_2"
    version_name = protocol.name[len('PROTOCOL_'):]
    return has_tls_version(version_name)
166
167
@functools.lru_cache
def has_tls_version(version):
    """Return True if a TLS/SSL version is compiled in and enabled at runtime.

    :param version: a TLSVersion member name such as "TLSv1_2", or an
        ssl.TLSVersion member
    :return: bool
    """
    # SSLv2 is never supported and has no TLSVersion member to look up.
    if version == "SSLv2":
        return False

    tls_version = (ssl.TLSVersion.__members__[version]
                   if isinstance(version, str) else version)

    # Compile-time switch, e.g. ssl.HAS_TLSv1_2.
    if not getattr(ssl, f'HAS_{tls_version.name}'):
        return False

    # A version may be compiled in yet disabled by a crypto policy or config
    # option; a default context exposes the effective bounds.
    ctx = ssl.SSLContext()
    floor = getattr(ctx, 'minimum_version', None)
    if (floor is not None
            and floor != ssl.TLSVersion.MINIMUM_SUPPORTED
            and tls_version < floor):
        return False
    ceiling = getattr(ctx, 'maximum_version', None)
    if (ceiling is not None
            and ceiling != ssl.TLSVersion.MAXIMUM_SUPPORTED
            and tls_version > ceiling):
        return False

    return True
203
204
def requires_tls_version(version):
    """Build a decorator that skips a test unless *version* is usable.

    :param version: TLS version name or ssl.TLSVersion member, as accepted
        by has_tls_version()
    :return: a decorator. The decorated test raises unittest.SkipTest when
        has_tls_version(version) is false and otherwise runs unchanged.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if has_tls_version(version):
                return func(*args, **kw)
            raise unittest.SkipTest(f"{version} is not available.")
        return wrapper
    return decorator
220
221
# Skip decorator for tests that need SSLContext.minimum_version /
# maximum_version support in the ssl module.
requires_minimum_version = unittest.skipUnless(
    condition=hasattr(ssl.SSLContext, 'minimum_version'),
    reason="required OpenSSL >= 1.1.0g",
)
226
227
def handle_error(prefix):
    """Report the exception currently being handled, preceded by *prefix*.

    The traceback is always formatted from sys.exc_info(), but it is only
    written to stdout when the test suite runs in verbose mode.
    """
    formatted_lines = traceback.format_exception(*sys.exc_info())
    # NOTE: joined with ' ' rather than '', so every chunk after the first
    # starts with one extra space.
    exc_format = ' '.join(formatted_lines)
    if not support.verbose:
        return
    sys.stdout.write(prefix + exc_format)
232
def can_clear_options():
    """Return True when the OpenSSL API version is 0.9.8m or higher."""
    minimum = (0, 9, 8, 13, 15)  # 0.9.8m
    return minimum <= ssl._OPENSSL_API_VERSION
236
def no_sslv2_implies_sslv3_hello():
    """Return True when the OpenSSL version is 0.9.7h or higher."""
    minimum = (0, 9, 7, 8, 15)  # 0.9.7h
    return minimum <= ssl.OPENSSL_VERSION_INFO
240
def have_verify_flags():
    """Return True when the OpenSSL version is 0.9.8 or higher."""
    minimum = (0, 9, 8, 0, 15)  # 0.9.8
    return minimum <= ssl.OPENSSL_VERSION_INFO
244
def _have_secp_curves():
    """Return True if ECDH is compiled in and a server context accepts
    the secp384r1 curve."""
    if ssl.HAS_ECDH:
        server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        try:
            server_ctx.set_ecdh_curve("secp384r1")
        except ValueError:
            return False
        return True
    return False


# Probed once at import time.
HAVE_SECP_CURVES = _have_secp_curves()
258
259
def utc_offset(): #NOTE: ignore issues like #1647654
    """Return the local UTC offset in seconds (local time = UTC + offset)."""
    in_dst = time.daylight and time.localtime().tm_isdst > 0
    return -(time.altzone if in_dst else time.timezone)
265
def asn1time(cert_time):
    """Adjust a certificate time string to match what OpenSSL reports.

    Some versions of OpenSSL ignore seconds (see #18207). Only OpenSSL
    0.9.8i is special-cased: there the seconds are zeroed and the day is
    space-padded. For every other version *cert_time* is returned unchanged.
    """
    if ssl._OPENSSL_API_VERSION != (0, 9, 8, 9, 15):
        return cert_time

    fmt = "%b %d %H:%M:%S %Y GMT"
    truncated = datetime.datetime.strptime(cert_time, fmt).replace(second=0)
    cert_time = truncated.strftime(fmt)
    # %d zero-pads the day, whereas ASN1_TIME_print() space-pads it.
    if cert_time[4] == "0":
        cert_time = cert_time[:4] + " " + cert_time[5:]
    return cert_time
279
# Skip decorator for tests that rely on Server Name Indication support.
needs_sni = unittest.skipUnless(
    condition=ssl.HAS_SNI,
    reason="SNI support needed for this test",
)
281
282
283def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *,
284                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
285                     ciphers=None, certfile=None, keyfile=None,
286                     **kwargs):
287    context = ssl.SSLContext(ssl_version)
288    if cert_reqs is not None:
289        if cert_reqs == ssl.CERT_NONE:
290            context.check_hostname = False
291        context.verify_mode = cert_reqs
292    if ca_certs is not None:
293        context.load_verify_locations(ca_certs)
294    if certfile is not None or keyfile is not None:
295        context.load_cert_chain(certfile, keyfile)
296    if ciphers is not None:
297        context.set_ciphers(ciphers)
298    return context.wrap_socket(sock, **kwargs)
299
300
def testing_context(server_cert=SIGNED_CERTFILE):
    """Create a client/server context pair that trust the test CA.

    Usage::

        client_context, server_context, hostname = testing_context()

    :param server_cert: SIGNED_CERTFILE or SIGNED_CERTFILE2; any other
        value raises ValueError.
    :return: (client_context, server_context, hostname), where hostname is
        the *_HOSTNAME constant paired with *server_cert*.
    """
    for known_cert, known_hostname in (
            (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
            (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME)):
        if server_cert == known_cert:
            hostname = known_hostname
            break
    else:
        raise ValueError(server_cert)

    # Client side: verify the server against the signing CA.
    client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client.load_verify_locations(SIGNING_CA)

    # Server side: present server_cert and also trust the signing CA.
    server = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server.load_cert_chain(server_cert)
    server.load_verify_locations(SIGNING_CA)

    return client, server, hostname
321
322
323class BasicSocketTests(unittest.TestCase):
324
325    def test_constants(self):
326        ssl.CERT_NONE
327        ssl.CERT_OPTIONAL
328        ssl.CERT_REQUIRED
329        ssl.OP_CIPHER_SERVER_PREFERENCE
330        ssl.OP_SINGLE_DH_USE
331        if ssl.HAS_ECDH:
332            ssl.OP_SINGLE_ECDH_USE
333        if ssl.OPENSSL_VERSION_INFO >= (1, 0):
334            ssl.OP_NO_COMPRESSION
335        self.assertIn(ssl.HAS_SNI, {True, False})
336        self.assertIn(ssl.HAS_ECDH, {True, False})
337        ssl.OP_NO_SSLv2
338        ssl.OP_NO_SSLv3
339        ssl.OP_NO_TLSv1
340        ssl.OP_NO_TLSv1_3
341        if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1):
342            ssl.OP_NO_TLSv1_1
343            ssl.OP_NO_TLSv1_2
344        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
345
346    def test_private_init(self):
347        with self.assertRaisesRegex(TypeError, "public constructor"):
348            with socket.socket() as s:
349                ssl.SSLSocket(s)
350
351    def test_str_for_enums(self):
352        # Make sure that the PROTOCOL_* constants have enum-like string
353        # reprs.
354        proto = ssl.PROTOCOL_TLS
355        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
356        ctx = ssl.SSLContext(proto)
357        self.assertIs(ctx.protocol, proto)
358
359    def test_random(self):
360        v = ssl.RAND_status()
361        if support.verbose:
362            sys.stdout.write("\n RAND_status is %d (%s)\n"
363                             % (v, (v and "sufficient randomness") or
364                                "insufficient randomness"))
365
366        data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
367        self.assertEqual(len(data), 16)
368        self.assertEqual(is_cryptographic, v == 1)
369        if v:
370            data = ssl.RAND_bytes(16)
371            self.assertEqual(len(data), 16)
372        else:
373            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
374
375        # negative num is invalid
376        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
377        self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
378
379        if hasattr(ssl, 'RAND_egd'):
380            self.assertRaises(TypeError, ssl.RAND_egd, 1)
381            self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
382        ssl.RAND_add("this is a random string", 75.0)
383        ssl.RAND_add(b"this is a random bytes object", 75.0)
384        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
385
    @unittest.skipUnless(os.name == 'posix', 'requires posix')
    def test_random_fork(self):
        # After fork(), the child and the parent must not draw identical
        # pseudo-random bytes. The child sends 16 bytes back over a pipe
        # and the parent compares them with 16 bytes of its own.
        status = ssl.RAND_status()
        if not status:
            self.fail("OpenSSL's PRNG has insufficient randomness")

        rfd, wfd = os.pipe()
        pid = os.fork()
        if pid == 0:
            # Child: always leaves via os._exit(), so it never returns into
            # the test runner. Any exception (including a failed assertion)
            # becomes exit status 1.
            try:
                os.close(rfd)
                child_random = ssl.RAND_pseudo_bytes(16)[0]
                self.assertEqual(len(child_random), 16)
                os.write(wfd, child_random)
                os.close(wfd)
            except BaseException:
                os._exit(1)
            else:
                os._exit(0)
        else:
            # Parent: wait for the child, which must have exited with status
            # 0, and only then read what it wrote to the pipe.
            os.close(wfd)
            self.addCleanup(os.close, rfd)
            _, status = os.waitpid(pid, 0)
            self.assertEqual(status, 0)

            child_random = os.read(rfd, 16)
            self.assertEqual(len(child_random), 16)
            parent_random = ssl.RAND_pseudo_bytes(16)[0]
            self.assertEqual(len(parent_random), 16)

            self.assertNotEqual(child_random, parent_random)
417
418    maxDiff = None
419
420    def test_parse_cert(self):
421        # note that this uses an 'unofficial' function in _ssl.c,
422        # provided solely for this test, to exercise the certificate
423        # parsing code
424        self.assertEqual(
425            ssl._ssl._test_decode_cert(CERTFILE),
426            CERTFILE_INFO
427        )
428        self.assertEqual(
429            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
430            SIGNED_CERTFILE_INFO
431        )
432
433        # Issue #13034: the subjectAltName in some certificates
434        # (notably projects.developer.nokia.com:443) wasn't parsed
435        p = ssl._ssl._test_decode_cert(NOKIACERT)
436        if support.verbose:
437            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
438        self.assertEqual(p['subjectAltName'],
439                         (('DNS', 'projects.developer.nokia.com'),
440                          ('DNS', 'projects.forum.nokia.com'))
441                        )
442        # extra OCSP and AIA fields
443        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
444        self.assertEqual(p['caIssuers'],
445                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
446        self.assertEqual(p['crlDistributionPoints'],
447                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
448
    def test_parse_cert_CVE_2019_5010(self):
        # Regression test for CVE-2019-5010. Judging by the constant's name,
        # TALOS_INVALID_CRLDP carries an invalid CRL distribution point.
        # Decoding must succeed, and the expected dict below contains no
        # 'crlDistributionPoints' key.
        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
        if support.verbose:
            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
        self.assertEqual(
            p,
            {
                'issuer': (
                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
                'notAfter': 'Jun 14 18:00:58 2028 GMT',
                'notBefore': 'Jun 18 18:00:58 2018 GMT',
                'serialNumber': '02',
                'subject': ((('countryName', 'UK'),),
                            (('commonName',
                              'codenomicon-vm-2.test.lal.cisco.com'),)),
                'subjectAltName': (
                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
                'version': 3
            }
        )
469
470    def test_parse_cert_CVE_2013_4238(self):
471        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
472        if support.verbose:
473            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
474        subject = ((('countryName', 'US'),),
475                   (('stateOrProvinceName', 'Oregon'),),
476                   (('localityName', 'Beaverton'),),
477                   (('organizationName', 'Python Software Foundation'),),
478                   (('organizationalUnitName', 'Python Core Development'),),
479                   (('commonName', 'null.python.org\x00example.org'),),
480                   (('emailAddress', 'python-dev@python.org'),))
481        self.assertEqual(p['subject'], subject)
482        self.assertEqual(p['issuer'], subject)
483        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
484            san = (('DNS', 'altnull.python.org\x00example.com'),
485                   ('email', 'null@python.org\x00user@example.org'),
486                   ('URI', 'http://null.python.org\x00http://example.org'),
487                   ('IP Address', '192.0.2.1'),
488                   ('IP Address', '2001:DB8:0:0:0:0:0:1'))
489        else:
490            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
491            san = (('DNS', 'altnull.python.org\x00example.com'),
492                   ('email', 'null@python.org\x00user@example.org'),
493                   ('URI', 'http://null.python.org\x00http://example.org'),
494                   ('IP Address', '192.0.2.1'),
495                   ('IP Address', '<invalid>'))
496
497        self.assertEqual(p['subjectAltName'], san)
498
    def test_parse_all_sans(self):
        # ALLSANFILE contains one subjectAltName entry of each kind. They
        # must be decoded in order; 'othername' entries are reported as
        # '<unsupported>' and DirName entries as a nested RDN tuple.
        p = ssl._ssl._test_decode_cert(ALLSANFILE)
        self.assertEqual(p['subjectAltName'],
            (
                ('DNS', 'allsans'),
                ('othername', '<unsupported>'),
                ('othername', '<unsupported>'),
                ('email', 'user@example.org'),
                ('DNS', 'www.example.org'),
                ('DirName',
                    ((('countryName', 'XY'),),
                    (('localityName', 'Castle Anthrax'),),
                    (('organizationName', 'Python Software Foundation'),),
                    (('commonName', 'dirname example'),))),
                ('URI', 'https://www.python.org/'),
                ('IP Address', '127.0.0.1'),
                ('IP Address', '0:0:0:0:0:0:0:1'),
                ('Registered ID', '1.2.3.4.5')
            )
        )
519
520    def test_DER_to_PEM(self):
521        with open(CAFILE_CACERT, 'r') as f:
522            pem = f.read()
523        d1 = ssl.PEM_cert_to_DER_cert(pem)
524        p2 = ssl.DER_cert_to_PEM_cert(d1)
525        d2 = ssl.PEM_cert_to_DER_cert(p2)
526        self.assertEqual(d1, d2)
527        if not p2.startswith(ssl.PEM_HEADER + '\n'):
528            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
529        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
530            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
531
532    def test_openssl_version(self):
533        n = ssl.OPENSSL_VERSION_NUMBER
534        t = ssl.OPENSSL_VERSION_INFO
535        s = ssl.OPENSSL_VERSION
536        self.assertIsInstance(n, int)
537        self.assertIsInstance(t, tuple)
538        self.assertIsInstance(s, str)
539        # Some sanity checks follow
540        # >= 0.9
541        self.assertGreaterEqual(n, 0x900000)
542        # < 4.0
543        self.assertLess(n, 0x40000000)
544        major, minor, fix, patch, status = t
545        self.assertGreaterEqual(major, 1)
546        self.assertLess(major, 4)
547        self.assertGreaterEqual(minor, 0)
548        self.assertLess(minor, 256)
549        self.assertGreaterEqual(fix, 0)
550        self.assertLess(fix, 256)
551        self.assertGreaterEqual(patch, 0)
552        self.assertLessEqual(patch, 63)
553        self.assertGreaterEqual(status, 0)
554        self.assertLessEqual(status, 15)
555        # Version string as returned by {Open,Libre}SSL, the format might change
556        if IS_LIBRESSL:
557            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
558                            (s, t, hex(n)))
559        else:
560            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
561                            (s, t, hex(n)))
562
563    @support.cpython_only
564    def test_refcycle(self):
565        # Issue #7943: an SSL object doesn't create reference cycles with
566        # itself.
567        s = socket.socket(socket.AF_INET)
568        ss = test_wrap_socket(s)
569        wr = weakref.ref(ss)
570        with support.check_warnings(("", ResourceWarning)):
571            del ss
572        self.assertEqual(wr(), None)
573
574    def test_wrapped_unconnected(self):
575        # Methods on an unconnected SSLSocket propagate the original
576        # OSError raise by the underlying socket object.
577        s = socket.socket(socket.AF_INET)
578        with test_wrap_socket(s) as ss:
579            self.assertRaises(OSError, ss.recv, 1)
580            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
581            self.assertRaises(OSError, ss.recvfrom, 1)
582            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
583            self.assertRaises(OSError, ss.send, b'x')
584            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
585            self.assertRaises(NotImplementedError, ss.dup)
586            self.assertRaises(NotImplementedError, ss.sendmsg,
587                              [b'x'], (), 0, ('0.0.0.0', 0))
588            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
589            self.assertRaises(NotImplementedError, ss.recvmsg_into,
590                              [bytearray(100)])
591
592    def test_timeout(self):
593        # Issue #8524: when creating an SSL socket, the timeout of the
594        # original socket should be retained.
595        for timeout in (None, 0.0, 5.0):
596            s = socket.socket(socket.AF_INET)
597            s.settimeout(timeout)
598            with test_wrap_socket(s) as ss:
599                self.assertEqual(timeout, ss.gettimeout())
600
601    def test_errors_sslwrap(self):
602        sock = socket.socket()
603        self.assertRaisesRegex(ValueError,
604                        "certfile must be specified",
605                        ssl.wrap_socket, sock, keyfile=CERTFILE)
606        self.assertRaisesRegex(ValueError,
607                        "certfile must be specified for server-side operations",
608                        ssl.wrap_socket, sock, server_side=True)
609        self.assertRaisesRegex(ValueError,
610                        "certfile must be specified for server-side operations",
611                         ssl.wrap_socket, sock, server_side=True, certfile="")
612        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
613            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
614                                     s.connect, (HOST, 8080))
615        with self.assertRaises(OSError) as cm:
616            with socket.socket() as sock:
617                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
618        self.assertEqual(cm.exception.errno, errno.ENOENT)
619        with self.assertRaises(OSError) as cm:
620            with socket.socket() as sock:
621                ssl.wrap_socket(sock,
622                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
623        self.assertEqual(cm.exception.errno, errno.ENOENT)
624        with self.assertRaises(OSError) as cm:
625            with socket.socket() as sock:
626                ssl.wrap_socket(sock,
627                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
628        self.assertEqual(cm.exception.errno, errno.ENOENT)
629
630    def bad_cert_test(self, certfile):
631        """Check that trying to use the given client certificate fails"""
632        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
633                                   certfile)
634        sock = socket.socket()
635        self.addCleanup(sock.close)
636        with self.assertRaises(ssl.SSLError):
637            test_wrap_socket(sock,
638                             certfile=certfile)
639
640    def test_empty_cert(self):
641        """Wrapping with an empty cert file"""
642        self.bad_cert_test("nullcert.pem")
643
644    def test_malformed_cert(self):
645        """Wrapping with a badly formatted certificate (syntax error)"""
646        self.bad_cert_test("badcert.pem")
647
648    def test_malformed_key(self):
649        """Wrapping with a badly formatted key (syntax error)"""
650        self.bad_cert_test("badkey.pem")
651
652    def test_match_hostname(self):
653        def ok(cert, hostname):
654            ssl.match_hostname(cert, hostname)
655        def fail(cert, hostname):
656            self.assertRaises(ssl.CertificateError,
657                              ssl.match_hostname, cert, hostname)
658
659        # -- Hostname matching --
660
661        cert = {'subject': ((('commonName', 'example.com'),),)}
662        ok(cert, 'example.com')
663        ok(cert, 'ExAmple.cOm')
664        fail(cert, 'www.example.com')
665        fail(cert, '.example.com')
666        fail(cert, 'example.org')
667        fail(cert, 'exampleXcom')
668
669        cert = {'subject': ((('commonName', '*.a.com'),),)}
670        ok(cert, 'foo.a.com')
671        fail(cert, 'bar.foo.a.com')
672        fail(cert, 'a.com')
673        fail(cert, 'Xa.com')
674        fail(cert, '.a.com')
675
676        # only match wildcards when they are the only thing
677        # in left-most segment
678        cert = {'subject': ((('commonName', 'f*.com'),),)}
679        fail(cert, 'foo.com')
680        fail(cert, 'f.com')
681        fail(cert, 'bar.com')
682        fail(cert, 'foo.a.com')
683        fail(cert, 'bar.foo.com')
684
685        # NULL bytes are bad, CVE-2013-4073
686        cert = {'subject': ((('commonName',
687                              'null.python.org\x00example.org'),),)}
688        ok(cert, 'null.python.org\x00example.org') # or raise an error?
689        fail(cert, 'example.org')
690        fail(cert, 'null.python.org')
691
692        # error cases with wildcards
693        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
694        fail(cert, 'bar.foo.a.com')
695        fail(cert, 'a.com')
696        fail(cert, 'Xa.com')
697        fail(cert, '.a.com')
698
699        cert = {'subject': ((('commonName', 'a.*.com'),),)}
700        fail(cert, 'a.foo.com')
701        fail(cert, 'a..com')
702        fail(cert, 'a.com')
703
704        # wildcard doesn't match IDNA prefix 'xn--'
705        idna = 'püthon.python.org'.encode("idna").decode("ascii")
706        cert = {'subject': ((('commonName', idna),),)}
707        ok(cert, idna)
708        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
709        fail(cert, idna)
710        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
711        fail(cert, idna)
712
713        # wildcard in first fragment and  IDNA A-labels in sequent fragments
714        # are supported.
715        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
716        cert = {'subject': ((('commonName', idna),),)}
717        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
718        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
719        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
720        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
721
722        # Slightly fake real-world example
723        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
724                'subject': ((('commonName', 'linuxfrz.org'),),),
725                'subjectAltName': (('DNS', 'linuxfr.org'),
726                                   ('DNS', 'linuxfr.com'),
727                                   ('othername', '<unsupported>'))}
728        ok(cert, 'linuxfr.org')
729        ok(cert, 'linuxfr.com')
730        # Not a "DNS" entry
731        fail(cert, '<unsupported>')
732        # When there is a subjectAltName, commonName isn't used
733        fail(cert, 'linuxfrz.org')
734
735        # A pristine real-world example
736        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
737                'subject': ((('countryName', 'US'),),
738                            (('stateOrProvinceName', 'California'),),
739                            (('localityName', 'Mountain View'),),
740                            (('organizationName', 'Google Inc'),),
741                            (('commonName', 'mail.google.com'),))}
742        ok(cert, 'mail.google.com')
743        fail(cert, 'gmail.com')
744        # Only commonName is considered
745        fail(cert, 'California')
746
747        # -- IPv4 matching --
748        cert = {'subject': ((('commonName', 'example.com'),),),
749                'subjectAltName': (('DNS', 'example.com'),
750                                   ('IP Address', '10.11.12.13'),
751                                   ('IP Address', '14.15.16.17'),
752                                   ('IP Address', '127.0.0.1'))}
753        ok(cert, '10.11.12.13')
754        ok(cert, '14.15.16.17')
755        # socket.inet_ntoa(socket.inet_aton('127.1')) == '127.0.0.1'
756        fail(cert, '127.1')
757        fail(cert, '14.15.16.17 ')
758        fail(cert, '14.15.16.17 extra data')
759        fail(cert, '14.15.16.18')
760        fail(cert, 'example.net')
761
762        # -- IPv6 matching --
763        if support.IPV6_ENABLED:
764            cert = {'subject': ((('commonName', 'example.com'),),),
765                    'subjectAltName': (
766                        ('DNS', 'example.com'),
767                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
768                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
769            ok(cert, '2001::cafe')
770            ok(cert, '2003::baba')
771            fail(cert, '2003::baba ')
772            fail(cert, '2003::baba extra data')
773            fail(cert, '2003::bebe')
774            fail(cert, 'example.net')
775
776        # -- Miscellaneous --
777
778        # Neither commonName nor subjectAltName
779        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
780                'subject': ((('countryName', 'US'),),
781                            (('stateOrProvinceName', 'California'),),
782                            (('localityName', 'Mountain View'),),
783                            (('organizationName', 'Google Inc'),))}
784        fail(cert, 'mail.google.com')
785
786        # No DNS entry in subjectAltName but a commonName
787        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
788                'subject': ((('countryName', 'US'),),
789                            (('stateOrProvinceName', 'California'),),
790                            (('localityName', 'Mountain View'),),
791                            (('commonName', 'mail.google.com'),)),
792                'subjectAltName': (('othername', 'blabla'), )}
793        ok(cert, 'mail.google.com')
794
795        # No DNS entry subjectAltName and no commonName
796        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
797                'subject': ((('countryName', 'US'),),
798                            (('stateOrProvinceName', 'California'),),
799                            (('localityName', 'Mountain View'),),
800                            (('organizationName', 'Google Inc'),)),
801                'subjectAltName': (('othername', 'blabla'),)}
802        fail(cert, 'google.com')
803
804        # Empty cert / no cert
805        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
806        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
807
808        # Issue #17980: avoid denials of service by refusing more than one
809        # wildcard per fragment.
810        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
811        with self.assertRaisesRegex(
812                ssl.CertificateError,
813                "partial wildcards in leftmost label are not supported"):
814            ssl.match_hostname(cert, 'axxb.example.com')
815
816        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
817        with self.assertRaisesRegex(
818                ssl.CertificateError,
819                "wildcard can only be present in the leftmost label"):
820            ssl.match_hostname(cert, 'www.sub.example.com')
821
822        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
823        with self.assertRaisesRegex(
824                ssl.CertificateError,
825                "too many wildcards"):
826            ssl.match_hostname(cert, 'axxbxxc.example.com')
827
828        cert = {'subject': ((('commonName', '*'),),)}
829        with self.assertRaisesRegex(
830                ssl.CertificateError,
831                "sole wildcard without additional labels are not support"):
832            ssl.match_hostname(cert, 'host')
833
834        cert = {'subject': ((('commonName', '*.com'),),)}
835        with self.assertRaisesRegex(
836                ssl.CertificateError,
837                r"hostname 'com' doesn't match '\*.com'"):
838            ssl.match_hostname(cert, 'com')
839
840        # extra checks for _inet_paton()
841        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
842            with self.assertRaises(ValueError):
843                ssl._inet_paton(invalid)
844        for ipaddr in ['127.0.0.1', '192.168.0.1']:
845            self.assertTrue(ssl._inet_paton(ipaddr))
846        if support.IPV6_ENABLED:
847            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
848                self.assertTrue(ssl._inet_paton(ipaddr))
849
850    def test_server_side(self):
851        # server_hostname doesn't work for server sockets
852        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
853        with socket.socket() as sock:
854            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
855                              server_hostname="some.hostname")
856
857    def test_unknown_channel_binding(self):
858        # should raise ValueError for unknown type
859        s = socket.create_server(('127.0.0.1', 0))
860        c = socket.socket(socket.AF_INET)
861        c.connect(s.getsockname())
862        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
863            with self.assertRaises(ValueError):
864                ss.get_channel_binding("unknown-type")
865        s.close()
866
867    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
868                         "'tls-unique' channel binding not available")
869    def test_tls_unique_channel_binding(self):
870        # unconnected should return None for known type
871        s = socket.socket(socket.AF_INET)
872        with test_wrap_socket(s) as ss:
873            self.assertIsNone(ss.get_channel_binding("tls-unique"))
874        # the same for server-side
875        s = socket.socket(socket.AF_INET)
876        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
877            self.assertIsNone(ss.get_channel_binding("tls-unique"))
878
879    def test_dealloc_warn(self):
880        ss = test_wrap_socket(socket.socket(socket.AF_INET))
881        r = repr(ss)
882        with self.assertWarns(ResourceWarning) as cm:
883            ss = None
884            support.gc_collect()
885        self.assertIn(r, str(cm.warning.args[0]))
886
887    def test_get_default_verify_paths(self):
888        paths = ssl.get_default_verify_paths()
889        self.assertEqual(len(paths), 6)
890        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
891
892        with support.EnvironmentVarGuard() as env:
893            env["SSL_CERT_DIR"] = CAPATH
894            env["SSL_CERT_FILE"] = CERTFILE
895            paths = ssl.get_default_verify_paths()
896            self.assertEqual(paths.cafile, CERTFILE)
897            self.assertEqual(paths.capath, CAPATH)
898
899    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
900    def test_enum_certificates(self):
901        self.assertTrue(ssl.enum_certificates("CA"))
902        self.assertTrue(ssl.enum_certificates("ROOT"))
903
904        self.assertRaises(TypeError, ssl.enum_certificates)
905        self.assertRaises(WindowsError, ssl.enum_certificates, "")
906
907        trust_oids = set()
908        for storename in ("CA", "ROOT"):
909            store = ssl.enum_certificates(storename)
910            self.assertIsInstance(store, list)
911            for element in store:
912                self.assertIsInstance(element, tuple)
913                self.assertEqual(len(element), 3)
914                cert, enc, trust = element
915                self.assertIsInstance(cert, bytes)
916                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
917                self.assertIsInstance(trust, (frozenset, set, bool))
918                if isinstance(trust, (frozenset, set)):
919                    trust_oids.update(trust)
920
921        serverAuth = "1.3.6.1.5.5.7.3.1"
922        self.assertIn(serverAuth, trust_oids)
923
924    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
925    def test_enum_crls(self):
926        self.assertTrue(ssl.enum_crls("CA"))
927        self.assertRaises(TypeError, ssl.enum_crls)
928        self.assertRaises(WindowsError, ssl.enum_crls, "")
929
930        crls = ssl.enum_crls("CA")
931        self.assertIsInstance(crls, list)
932        for element in crls:
933            self.assertIsInstance(element, tuple)
934            self.assertEqual(len(element), 2)
935            self.assertIsInstance(element[0], bytes)
936            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
937
938
939    def test_asn1object(self):
940        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
941                    '1.3.6.1.5.5.7.3.1')
942
943        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
944        self.assertEqual(val, expected)
945        self.assertEqual(val.nid, 129)
946        self.assertEqual(val.shortname, 'serverAuth')
947        self.assertEqual(val.longname, 'TLS Web Server Authentication')
948        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
949        self.assertIsInstance(val, ssl._ASN1Object)
950        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
951
952        val = ssl._ASN1Object.fromnid(129)
953        self.assertEqual(val, expected)
954        self.assertIsInstance(val, ssl._ASN1Object)
955        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
956        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
957            ssl._ASN1Object.fromnid(100000)
958        for i in range(1000):
959            try:
960                obj = ssl._ASN1Object.fromnid(i)
961            except ValueError:
962                pass
963            else:
964                self.assertIsInstance(obj.nid, int)
965                self.assertIsInstance(obj.shortname, str)
966                self.assertIsInstance(obj.longname, str)
967                self.assertIsInstance(obj.oid, (str, type(None)))
968
969        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
970        self.assertEqual(val, expected)
971        self.assertIsInstance(val, ssl._ASN1Object)
972        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
973        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
974                         expected)
975        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
976            ssl._ASN1Object.fromname('serverauth')
977
978    def test_purpose_enum(self):
979        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
980        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
981        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
982        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
983        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
984        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
985                              '1.3.6.1.5.5.7.3.1')
986
987        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
988        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
989        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
990        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
991        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
992        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
993                              '1.3.6.1.5.5.7.3.2')
994
995    def test_unsupported_dtls(self):
996        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
997        self.addCleanup(s.close)
998        with self.assertRaises(NotImplementedError) as cx:
999            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
1000        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1001        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1002        with self.assertRaises(NotImplementedError) as cx:
1003            ctx.wrap_socket(s)
1004        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1005
1006    def cert_time_ok(self, timestring, timestamp):
1007        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
1008
1009    def cert_time_fail(self, timestring):
1010        with self.assertRaises(ValueError):
1011            ssl.cert_time_to_seconds(timestring)
1012
1013    @unittest.skipUnless(utc_offset(),
1014                         'local time needs to be different from UTC')
1015    def test_cert_time_to_seconds_timezone(self):
1016        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
1017        #               results if local timezone is not UTC
1018        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
1019        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
1020
1021    def test_cert_time_to_seconds(self):
1022        timestring = "Jan  5 09:34:43 2018 GMT"
1023        ts = 1515144883.0
1024        self.cert_time_ok(timestring, ts)
1025        # accept keyword parameter, assert its name
1026        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
1027        # accept both %e and %d (space or zero generated by strftime)
1028        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
1029        # case-insensitive
1030        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
1031        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
1032        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
1033        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
1034        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
1035        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
1036        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
1037        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
1038
1039        newyear_ts = 1230768000.0
1040        # leap seconds
1041        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
1042        # same timestamp
1043        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
1044
1045        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
1046        #  allow 60th second (even if it is not a leap second)
1047        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
1048        #  allow 2nd leap second for compatibility with time.strptime()
1049        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
1050        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
1051
1052        # no special treatment for the special value:
1053        #   99991231235959Z (rfc 5280)
1054        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
1055
1056    @support.run_with_locale('LC_ALL', '')
1057    def test_cert_time_to_seconds_locale(self):
1058        # `cert_time_to_seconds()` should be locale independent
1059
1060        def local_february_name():
1061            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
1062
1063        if local_february_name().lower() == 'feb':
1064            self.skipTest("locale-specific month name needs to be "
1065                          "different from C locale")
1066
1067        # locale-independent
1068        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
1069        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
1070
1071    def test_connect_ex_error(self):
1072        server = socket.socket(socket.AF_INET)
1073        self.addCleanup(server.close)
1074        port = support.bind_port(server)  # Reserve port but don't listen
1075        s = test_wrap_socket(socket.socket(socket.AF_INET),
1076                            cert_reqs=ssl.CERT_REQUIRED)
1077        self.addCleanup(s.close)
1078        rc = s.connect_ex((HOST, port))
1079        # Issue #19919: Windows machines or VMs hosted on Windows
1080        # machines sometimes return EWOULDBLOCK.
1081        errors = (
1082            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1083            errno.EWOULDBLOCK,
1084        )
1085        self.assertIn(rc, errors)
1086
1087
1088class ContextTests(unittest.TestCase):
1089
1090    def test_constructor(self):
1091        for protocol in PROTOCOLS:
1092            ssl.SSLContext(protocol)
1093        ctx = ssl.SSLContext()
1094        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1095        self.assertRaises(ValueError, ssl.SSLContext, -1)
1096        self.assertRaises(ValueError, ssl.SSLContext, 42)
1097
1098    def test_protocol(self):
1099        for proto in PROTOCOLS:
1100            ctx = ssl.SSLContext(proto)
1101            self.assertEqual(ctx.protocol, proto)
1102
1103    def test_ciphers(self):
1104        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1105        ctx.set_ciphers("ALL")
1106        ctx.set_ciphers("DEFAULT")
1107        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1108            ctx.set_ciphers("^$:,;?*'dorothyx")
1109
1110    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1111                         "Test applies only to Python default ciphers")
1112    def test_python_ciphers(self):
1113        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1114        ciphers = ctx.get_ciphers()
1115        for suite in ciphers:
1116            name = suite['name']
1117            self.assertNotIn("PSK", name)
1118            self.assertNotIn("SRP", name)
1119            self.assertNotIn("MD5", name)
1120            self.assertNotIn("RC4", name)
1121            self.assertNotIn("3DES", name)
1122
1123    @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old')
1124    def test_get_ciphers(self):
1125        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1126        ctx.set_ciphers('AESGCM')
1127        names = set(d['name'] for d in ctx.get_ciphers())
1128        self.assertIn('AES256-GCM-SHA384', names)
1129        self.assertIn('AES128-GCM-SHA256', names)
1130
1131    def test_options(self):
1132        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1133        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
1134        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
1135        # SSLContext also enables these by default
1136        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
1137                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
1138                    OP_ENABLE_MIDDLEBOX_COMPAT)
1139        self.assertEqual(default, ctx.options)
1140        ctx.options |= ssl.OP_NO_TLSv1
1141        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
1142        if can_clear_options():
1143            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
1144            self.assertEqual(default, ctx.options)
1145            ctx.options = 0
1146            # Ubuntu has OP_NO_SSLv3 forced on by default
1147            self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
1148        else:
1149            with self.assertRaises(ValueError):
1150                ctx.options = 0
1151
1152    def test_verify_mode_protocol(self):
1153        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1154        # Default value
1155        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1156        ctx.verify_mode = ssl.CERT_OPTIONAL
1157        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1158        ctx.verify_mode = ssl.CERT_REQUIRED
1159        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1160        ctx.verify_mode = ssl.CERT_NONE
1161        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1162        with self.assertRaises(TypeError):
1163            ctx.verify_mode = None
1164        with self.assertRaises(ValueError):
1165            ctx.verify_mode = 42
1166
1167        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1168        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1169        self.assertFalse(ctx.check_hostname)
1170
1171        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1172        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1173        self.assertTrue(ctx.check_hostname)
1174
1175    def test_hostname_checks_common_name(self):
1176        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1177        self.assertTrue(ctx.hostname_checks_common_name)
1178        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1179            ctx.hostname_checks_common_name = True
1180            self.assertTrue(ctx.hostname_checks_common_name)
1181            ctx.hostname_checks_common_name = False
1182            self.assertFalse(ctx.hostname_checks_common_name)
1183            ctx.hostname_checks_common_name = True
1184            self.assertTrue(ctx.hostname_checks_common_name)
1185        else:
1186            with self.assertRaises(AttributeError):
1187                ctx.hostname_checks_common_name = True
1188
    @requires_minimum_version
    @unittest.skipIf(IS_LIBRESSL, "see bpo-34001")
    def test_min_max_version(self):
        """Check SSLContext.minimum_version / maximum_version.

        Covers the (vendor-dependent) defaults, round-tripping explicit
        TLSVersion values, how the MINIMUM_SUPPORTED / MAXIMUM_SUPPORTED
        sentinels read back when assigned to the opposite attribute, and
        that a version-specific protocol refuses changes.
        """
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like
        # Fedora override the setting to TLS 1.0.
        minimum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MINIMUM_SUPPORTED,
            # Fedora 29 uses TLS 1.0 by default
            ssl.TLSVersion.TLSv1,
            # RHEL 8 uses TLS 1.2 by default
            ssl.TLSVersion.TLSv1_2
        }
        maximum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MAXIMUM_SUPPORTED,
            # Fedora 32 uses TLS 1.3 by default
            ssl.TLSVersion.TLSv1_3
        }

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertIn(
            ctx.maximum_version, maximum_range
        )

        # Explicit versions read back unchanged.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
        )

        # MINIMUM_SUPPORTED on minimum_version reads back as the sentinel.
        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        ctx.maximum_version = ssl.TLSVersion.TLSv1
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1
        )

        # Likewise MAXIMUM_SUPPORTED on maximum_version.
        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )

        # Assigned to maximum_version, MINIMUM_SUPPORTED is reported as a
        # concrete version (TLSv1 or SSLv3 are both accepted here).
        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        self.assertIn(
            ctx.maximum_version,
            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3}
        )

        # Assigned to minimum_version, MAXIMUM_SUPPORTED is reported as a
        # concrete version (TLSv1_2 or TLSv1_3 are both accepted here).
        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertIn(
            ctx.minimum_version,
            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
        )

        # An arbitrary integer is not an acceptable version value.
        with self.assertRaises(ValueError):
            ctx.minimum_version = 42

        # A version-specific protocol: the limits can be read, but
        # assigning either of them raises ValueError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )
        with self.assertRaises(ValueError):
            ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        with self.assertRaises(ValueError):
            ctx.maximum_version = ssl.TLSVersion.TLSv1
1267
1268
1269    @unittest.skipUnless(have_verify_flags(),
1270                         "verify_flags need OpenSSL > 0.9.8")
1271    def test_verify_flags(self):
1272        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1273        # default value
1274        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1275        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1276        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1277        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1278        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1279        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1280        ctx.verify_flags = ssl.VERIFY_DEFAULT
1281        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1282        # supports any value
1283        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1284        self.assertEqual(ctx.verify_flags,
1285                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1286        with self.assertRaises(TypeError):
1287            ctx.verify_flags = None
1288
    def test_load_cert_chain(self):
        """Exercise SSLContext.load_cert_chain() with valid and invalid input.

        Covers combined and separate key/cert files, str and bytes paths,
        mismatched key/cert pairs, and every accepted form of the
        ``password`` argument (str, bytes, bytearray, callables), plus the
        errors raised for wrong passwords, oversized passwords, wrong
        callback return types and callbacks that raise.
        """
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        # certfile is mandatory; passing only keyfile is a TypeError.
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        # A missing file surfaces as OSError carrying ENOENT.
        with self.assertRaises(OSError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        # Unparsable or empty PEM data is reported by OpenSSL's PEM layer.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A file holding only one half of the pair (or the halves swapped)
        # is rejected.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert (the same ctx is reused for
        # all remaining checks).
        # The password may be str, bytes or bytearray, by keyword or
        # positionally as the third argument.
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        with self.assertRaisesRegex(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Return a string larger than this.
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback
        # Each helper below returns (or raises) one kind of value so the
        # callback path can be checked for every case accepted or rejected
        # above for plain password values.
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        # Callable instances and bound methods work as callbacks too.
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegex(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        # An exception raised by the callback propagates to the caller.
        with self.assertRaisesRegex(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1371
1372    def test_load_verify_locations(self):
1373        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1374        ctx.load_verify_locations(CERTFILE)
1375        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1376        ctx.load_verify_locations(BYTES_CERTFILE)
1377        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1378        self.assertRaises(TypeError, ctx.load_verify_locations)
1379        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1380        with self.assertRaises(OSError) as cm:
1381            ctx.load_verify_locations(NONEXISTINGCERT)
1382        self.assertEqual(cm.exception.errno, errno.ENOENT)
1383        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1384            ctx.load_verify_locations(BADCERT)
1385        ctx.load_verify_locations(CERTFILE, CAPATH)
1386        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1387
1388        # Issue #10989: crash if the second argument type is invalid
1389        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1390
    def test_load_verify_cadata(self):
        """Load CA certificates from in-memory data via ``cadata``.

        PEM data is passed as str and DER data as bytes.
        cert_store_stats()["x509_ca"] counts the CA certificates in the
        context's store; loading a certificate that is already present does
        not increase the count.
        """
        # Read both CA certificates as PEM text and derive DER copies.
        with open(CAFILE_CACERT) as f:
            cacert_pem = f.read()
        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
        with open(CAFILE_NEURONIO) as f:
            neuronio_pem = f.read()
        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)

        # test PEM: certificates are added one at a time
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
        ctx.load_verify_locations(cadata=cacert_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined: several PEM certificates in a single string
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = "\n".join((cacert_pem, neuronio_pem))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # with junk around the certs (and a duplicate certificate); only
        # the two distinct certificates are counted
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
                    neuronio_pem, "tail"]
        ctx.load_verify_locations(cadata="\n".join(combined))
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # test DER: bytes cadata
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_verify_locations(cadata=cacert_der)
        ctx.load_verify_locations(cadata=neuronio_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=cacert_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined: concatenated DER certificates in one bytes object
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = b"".join((cacert_der, neuronio_der))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # error cases: wrong type, and undecodable str / bytes payloads
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)

        with self.assertRaisesRegex(ssl.SSLError, "no start line"):
            ctx.load_verify_locations(cadata="broken")
        with self.assertRaisesRegex(ssl.SSLError, "not enough data"):
            ctx.load_verify_locations(cadata=b"broken")
1447
1448
1449    def test_load_dh_params(self):
1450        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1451        ctx.load_dh_params(DHFILE)
1452        if os.name != 'nt':
1453            ctx.load_dh_params(BYTES_DHFILE)
1454        self.assertRaises(TypeError, ctx.load_dh_params)
1455        self.assertRaises(TypeError, ctx.load_dh_params, None)
1456        with self.assertRaises(FileNotFoundError) as cm:
1457            ctx.load_dh_params(NONEXISTINGCERT)
1458        self.assertEqual(cm.exception.errno, errno.ENOENT)
1459        with self.assertRaises(ssl.SSLError) as cm:
1460            ctx.load_dh_params(CERTFILE)
1461
1462    def test_session_stats(self):
1463        for proto in PROTOCOLS:
1464            ctx = ssl.SSLContext(proto)
1465            self.assertEqual(ctx.session_stats(), {
1466                'number': 0,
1467                'connect': 0,
1468                'connect_good': 0,
1469                'connect_renegotiate': 0,
1470                'accept': 0,
1471                'accept_good': 0,
1472                'accept_renegotiate': 0,
1473                'hits': 0,
1474                'misses': 0,
1475                'timeouts': 0,
1476                'cache_full': 0,
1477            })
1478
1479    def test_set_default_verify_paths(self):
1480        # There's not much we can do to test that it acts as expected,
1481        # so just check it doesn't crash or raise an exception.
1482        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1483        ctx.set_default_verify_paths()
1484
1485    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1486    def test_set_ecdh_curve(self):
1487        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1488        ctx.set_ecdh_curve("prime256v1")
1489        ctx.set_ecdh_curve(b"prime256v1")
1490        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1491        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1492        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1493        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1494
1495    @needs_sni
1496    def test_sni_callback(self):
1497        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1498
1499        # set_servername_callback expects a callable, or None
1500        self.assertRaises(TypeError, ctx.set_servername_callback)
1501        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1502        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1503        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1504
1505        def dummycallback(sock, servername, ctx):
1506            pass
1507        ctx.set_servername_callback(None)
1508        ctx.set_servername_callback(dummycallback)
1509
    @needs_sni
    def test_sni_callback_refcycle(self):
        # Reference cycles through the servername callback are detected
        # and cleared.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # The default argument ``cycle=ctx`` makes the callback hold a strong
        # reference to the context; registering the callback on that same
        # context closes the reference cycle under test.
        def dummycallback(sock, servername, ctx, cycle=ctx):
            pass
        ctx.set_servername_callback(dummycallback)
        # Keep only a weak reference so the context's death can be observed.
        wr = weakref.ref(ctx)
        # Drop the local strong references; after this only the cycle itself
        # keeps the context alive.
        del ctx, dummycallback
        # The cyclic garbage collector must be able to reclaim the context.
        gc.collect()
        self.assertIs(wr(), None)
1522
1523    def test_cert_store_stats(self):
1524        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1525        self.assertEqual(ctx.cert_store_stats(),
1526            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1527        ctx.load_cert_chain(CERTFILE)
1528        self.assertEqual(ctx.cert_store_stats(),
1529            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1530        ctx.load_verify_locations(CERTFILE)
1531        self.assertEqual(ctx.cert_store_stats(),
1532            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1533        ctx.load_verify_locations(CAFILE_CACERT)
1534        self.assertEqual(ctx.cert_store_stats(),
1535            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1536
1537    def test_get_ca_certs(self):
1538        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1539        self.assertEqual(ctx.get_ca_certs(), [])
1540        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1541        ctx.load_verify_locations(CERTFILE)
1542        self.assertEqual(ctx.get_ca_certs(), [])
1543        # but CAFILE_CACERT is a CA cert
1544        ctx.load_verify_locations(CAFILE_CACERT)
1545        self.assertEqual(ctx.get_ca_certs(),
1546            [{'issuer': ((('organizationName', 'Root CA'),),
1547                         (('organizationalUnitName', 'http://www.cacert.org'),),
1548                         (('commonName', 'CA Cert Signing Authority'),),
1549                         (('emailAddress', 'support@cacert.org'),)),
1550              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
1551              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
1552              'serialNumber': '00',
1553              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1554              'subject': ((('organizationName', 'Root CA'),),
1555                          (('organizationalUnitName', 'http://www.cacert.org'),),
1556                          (('commonName', 'CA Cert Signing Authority'),),
1557                          (('emailAddress', 'support@cacert.org'),)),
1558              'version': 3}])
1559
1560        with open(CAFILE_CACERT) as f:
1561            pem = f.read()
1562        der = ssl.PEM_cert_to_DER_cert(pem)
1563        self.assertEqual(ctx.get_ca_certs(True), [der])
1564
1565    def test_load_default_certs(self):
1566        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1567        ctx.load_default_certs()
1568
1569        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1570        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1571        ctx.load_default_certs()
1572
1573        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1574        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1575
1576        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1577        self.assertRaises(TypeError, ctx.load_default_certs, None)
1578        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1579
1580    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
1581    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
1582    def test_load_default_certs_env(self):
1583        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1584        with support.EnvironmentVarGuard() as env:
1585            env["SSL_CERT_DIR"] = CAPATH
1586            env["SSL_CERT_FILE"] = CERTFILE
1587            ctx.load_default_certs()
1588            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1589
1590    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
1591    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
1592    def test_load_default_certs_env_windows(self):
1593        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1594        ctx.load_default_certs()
1595        stats = ctx.cert_store_stats()
1596
1597        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1598        with support.EnvironmentVarGuard() as env:
1599            env["SSL_CERT_DIR"] = CAPATH
1600            env["SSL_CERT_FILE"] = CERTFILE
1601            ctx.load_default_certs()
1602            stats["x509"] += 1
1603            self.assertEqual(ctx.cert_store_stats(), stats)
1604
1605    def _assert_context_options(self, ctx):
1606        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1607        if OP_NO_COMPRESSION != 0:
1608            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1609                             OP_NO_COMPRESSION)
1610        if OP_SINGLE_DH_USE != 0:
1611            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1612                             OP_SINGLE_DH_USE)
1613        if OP_SINGLE_ECDH_USE != 0:
1614            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1615                             OP_SINGLE_ECDH_USE)
1616        if OP_CIPHER_SERVER_PREFERENCE != 0:
1617            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1618                             OP_CIPHER_SERVER_PREFERENCE)
1619
1620    def test_create_default_context(self):
1621        ctx = ssl.create_default_context()
1622
1623        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1624        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1625        self.assertTrue(ctx.check_hostname)
1626        self._assert_context_options(ctx)
1627
1628        with open(SIGNING_CA) as f:
1629            cadata = f.read()
1630        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1631                                         cadata=cadata)
1632        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1633        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1634        self._assert_context_options(ctx)
1635
1636        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1637        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1638        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1639        self._assert_context_options(ctx)
1640
1641    def test__create_stdlib_context(self):
1642        ctx = ssl._create_stdlib_context()
1643        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1644        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1645        self.assertFalse(ctx.check_hostname)
1646        self._assert_context_options(ctx)
1647
1648        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1649        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1650        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1651        self._assert_context_options(ctx)
1652
1653        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
1654                                         cert_reqs=ssl.CERT_REQUIRED,
1655                                         check_hostname=True)
1656        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1657        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1658        self.assertTrue(ctx.check_hostname)
1659        self._assert_context_options(ctx)
1660
1661        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1662        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1663        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1664        self._assert_context_options(ctx)
1665
    def test_check_hostname(self):
        # Walks through the interplay between check_hostname and verify_mode
        # on one PROTOCOL_TLS context.  Each step builds on the state left by
        # the previous one, so the statement order is significant.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        # Initial state: no hostname check, no verification.
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)

        # Auto set CERT_REQUIRED
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        # Setting verify_mode explicitly does not re-enable check_hostname.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        # Changing verify_mode does not affect check_hostname
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        # Auto set
        # (enabling check_hostname while verify_mode is CERT_NONE raises
        # verify_mode to CERT_REQUIRED)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_OPTIONAL
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
        # keep CERT_OPTIONAL
        # (enabling check_hostname does not change a CERT_OPTIONAL mode)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)

        # Cannot set CERT_NONE with check_hostname enabled
        with self.assertRaises(ValueError):
            ctx.verify_mode = ssl.CERT_NONE
        # Once check_hostname is disabled, CERT_NONE is accepted again.
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        ctx.verify_mode = ssl.CERT_NONE
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1708
1709    def test_context_client_server(self):
1710        # PROTOCOL_TLS_CLIENT has sane defaults
1711        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1712        self.assertTrue(ctx.check_hostname)
1713        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1714
1715        # PROTOCOL_TLS_SERVER has different but also sane defaults
1716        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1717        self.assertFalse(ctx.check_hostname)
1718        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1719
1720    def test_context_custom_class(self):
1721        class MySSLSocket(ssl.SSLSocket):
1722            pass
1723
1724        class MySSLObject(ssl.SSLObject):
1725            pass
1726
1727        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1728        ctx.sslsocket_class = MySSLSocket
1729        ctx.sslobject_class = MySSLObject
1730
1731        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1732            self.assertIsInstance(sock, MySSLSocket)
1733        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO())
1734        self.assertIsInstance(obj, MySSLObject)
1735
1736    @unittest.skipUnless(IS_OPENSSL_1_1_1, "Test requires OpenSSL 1.1.1")
1737    def test_num_tickest(self):
1738        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1739        self.assertEqual(ctx.num_tickets, 2)
1740        ctx.num_tickets = 1
1741        self.assertEqual(ctx.num_tickets, 1)
1742        ctx.num_tickets = 0
1743        self.assertEqual(ctx.num_tickets, 0)
1744        with self.assertRaises(ValueError):
1745            ctx.num_tickets = -1
1746        with self.assertRaises(TypeError):
1747            ctx.num_tickets = None
1748
1749        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1750        self.assertEqual(ctx.num_tickets, 2)
1751        with self.assertRaises(ValueError):
1752            ctx.num_tickets = 1
1753
1754
1755class SSLErrorTests(unittest.TestCase):
1756
1757    def test_str(self):
1758        # The str() of a SSLError doesn't include the errno
1759        e = ssl.SSLError(1, "foo")
1760        self.assertEqual(str(e), "foo")
1761        self.assertEqual(e.errno, 1)
1762        # Same for a subclass
1763        e = ssl.SSLZeroReturnError(1, "foo")
1764        self.assertEqual(str(e), "foo")
1765        self.assertEqual(e.errno, 1)
1766
1767    def test_lib_reason(self):
1768        # Test the library and reason attributes
1769        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1770        with self.assertRaises(ssl.SSLError) as cm:
1771            ctx.load_dh_params(CERTFILE)
1772        self.assertEqual(cm.exception.library, 'PEM')
1773        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
1774        s = str(cm.exception)
1775        self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)
1776
1777    def test_subclass(self):
1778        # Check that the appropriate SSLError subclass is raised
1779        # (this only tests one of them)
1780        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1781        ctx.check_hostname = False
1782        ctx.verify_mode = ssl.CERT_NONE
1783        with socket.create_server(("127.0.0.1", 0)) as s:
1784            c = socket.create_connection(s.getsockname())
1785            c.setblocking(False)
1786            with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c:
1787                with self.assertRaises(ssl.SSLWantReadError) as cm:
1788                    c.do_handshake()
1789                s = str(cm.exception)
1790                self.assertTrue(s.startswith("The operation did not complete (read)"), s)
1791                # For compatibility
1792                self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
1793
1794
1795    def test_bad_server_hostname(self):
1796        ctx = ssl.create_default_context()
1797        with self.assertRaises(ValueError):
1798            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
1799                         server_hostname="")
1800        with self.assertRaises(ValueError):
1801            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
1802                         server_hostname=".example.org")
1803        with self.assertRaises(TypeError):
1804            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
1805                         server_hostname="example.org\x00evil.com")
1806
1807
class MemoryBIOTests(unittest.TestCase):
    """Behaviour of ssl.MemoryBIO as an in-memory byte buffer."""

    def test_read_write(self):
        bio = ssl.MemoryBIO()
        # A read() without a size drains everything written so far, and the
        # next read() then comes back empty.
        for chunks, drained in (([b'foo'], b'foo'),
                                ([b'foo', b'bar'], b'foobar')):
            for chunk in chunks:
                bio.write(chunk)
            self.assertEqual(bio.read(), drained)
            self.assertEqual(bio.read(), b'')
        # read(n) returns at most n bytes.
        bio.write(b'baz')
        for size, piece in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(bio.read(size), piece)

    def test_eof(self):
        bio = ssl.MemoryBIO()
        # eof turns true only after write_eof() has been called *and* every
        # buffered byte has been consumed.
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertFalse(bio.eof)
        bio.write(b'foo')
        self.assertFalse(bio.eof)
        bio.write_eof()
        self.assertFalse(bio.eof)
        for size, piece, at_eof in ((2, b'fo', False), (1, b'o', True)):
            self.assertEqual(bio.read(size), piece)
            self.assertEqual(bool(bio.eof), at_eof)
        self.assertEqual(bio.read(), b'')
        self.assertTrue(bio.eof)

    def test_pending(self):
        bio = ssl.MemoryBIO()
        self.assertEqual(bio.pending, 0)
        bio.write(b'foo')
        self.assertEqual(bio.pending, 3)
        # Reading one byte at a time shrinks the count...
        for remaining in (2, 1, 0):
            bio.read(1)
            self.assertEqual(bio.pending, remaining)
        # ...and writing one byte at a time grows it again.
        for total in (1, 2, 3):
            bio.write(b'x')
            self.assertEqual(bio.pending, total)
        bio.read()
        self.assertEqual(bio.pending, 0)

    def test_buffer_types(self):
        # bytes, bytearray and memoryview are all accepted by write().
        bio = ssl.MemoryBIO()
        for data in (b'foo', bytearray(b'bar'), memoryview(b'baz')):
            bio.write(data)
            self.assertEqual(bio.read(), bytes(data))

    def test_error_types(self):
        # Anything that is not a bytes-like object is a TypeError.
        bio = ssl.MemoryBIO()
        for not_bytes in ('foo', None, True, 1):
            self.assertRaises(TypeError, bio.write, not_bytes)
1869
1870
1871class SSLObjectTests(unittest.TestCase):
1872    def test_private_init(self):
1873        bio = ssl.MemoryBIO()
1874        with self.assertRaisesRegex(TypeError, "public constructor"):
1875            ssl.SSLObject(bio, bio)
1876
1877    def test_unwrap(self):
1878        client_ctx, server_ctx, hostname = testing_context()
1879        c_in = ssl.MemoryBIO()
1880        c_out = ssl.MemoryBIO()
1881        s_in = ssl.MemoryBIO()
1882        s_out = ssl.MemoryBIO()
1883        client = client_ctx.wrap_bio(c_in, c_out, server_hostname=hostname)
1884        server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
1885
1886        # Loop on the handshake for a bit to get it settled
1887        for _ in range(5):
1888            try:
1889                client.do_handshake()
1890            except ssl.SSLWantReadError:
1891                pass
1892            if c_out.pending:
1893                s_in.write(c_out.read())
1894            try:
1895                server.do_handshake()
1896            except ssl.SSLWantReadError:
1897                pass
1898            if s_out.pending:
1899                c_in.write(s_out.read())
1900        # Now the handshakes should be complete (don't raise WantReadError)
1901        client.do_handshake()
1902        server.do_handshake()
1903
1904        # Now if we unwrap one side unilaterally, it should send close-notify
1905        # and raise WantReadError:
1906        with self.assertRaises(ssl.SSLWantReadError):
1907            client.unwrap()
1908
1909        # But server.unwrap() does not raise, because it reads the client's
1910        # close-notify:
1911        s_in.write(c_out.read())
1912        server.unwrap()
1913
1914        # And now that the client gets the server's close-notify, it doesn't
1915        # raise either.
1916        c_in.write(s_out.read())
1917        client.unwrap()
1918
1919class SimpleBackgroundTests(unittest.TestCase):
1920    """Tests that connect to a simple server running in the background"""
1921
1922    def setUp(self):
1923        server = ThreadedEchoServer(SIGNED_CERTFILE)
1924        self.server_addr = (HOST, server.port)
1925        server.__enter__()
1926        self.addCleanup(server.__exit__, None, None, None)
1927
1928    def test_connect(self):
1929        with test_wrap_socket(socket.socket(socket.AF_INET),
1930                            cert_reqs=ssl.CERT_NONE) as s:
1931            s.connect(self.server_addr)
1932            self.assertEqual({}, s.getpeercert())
1933            self.assertFalse(s.server_side)
1934
1935        # this should succeed because we specify the root cert
1936        with test_wrap_socket(socket.socket(socket.AF_INET),
1937                            cert_reqs=ssl.CERT_REQUIRED,
1938                            ca_certs=SIGNING_CA) as s:
1939            s.connect(self.server_addr)
1940            self.assertTrue(s.getpeercert())
1941            self.assertFalse(s.server_side)
1942
1943    def test_connect_fail(self):
1944        # This should fail because we have no verification certs. Connection
1945        # failure crashes ThreadedEchoServer, so run this in an independent
1946        # test method.
1947        s = test_wrap_socket(socket.socket(socket.AF_INET),
1948                            cert_reqs=ssl.CERT_REQUIRED)
1949        self.addCleanup(s.close)
1950        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
1951                               s.connect, self.server_addr)
1952
1953    def test_connect_ex(self):
1954        # Issue #11326: check connect_ex() implementation
1955        s = test_wrap_socket(socket.socket(socket.AF_INET),
1956                            cert_reqs=ssl.CERT_REQUIRED,
1957                            ca_certs=SIGNING_CA)
1958        self.addCleanup(s.close)
1959        self.assertEqual(0, s.connect_ex(self.server_addr))
1960        self.assertTrue(s.getpeercert())
1961
    def test_non_blocking_connect_ex(self):
        # Issue #11326: non-blocking connect_ex() should allow handshake
        # to proceed after the socket gets ready.
        # The handshake is deferred (do_handshake_on_connect=False) so it can
        # be driven manually once the non-blocking connect has finished.
        s = test_wrap_socket(socket.socket(socket.AF_INET),
                            cert_reqs=ssl.CERT_REQUIRED,
                            ca_certs=SIGNING_CA,
                            do_handshake_on_connect=False)
        self.addCleanup(s.close)
        s.setblocking(False)
        rc = s.connect_ex(self.server_addr)
        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
        # Wait for connect to finish
        # (waits at most 5 seconds for writability; the result is not checked)
        select.select([], [s], [], 5.0)
        # Non-blocking handshake
        # Retry do_handshake(), each time waiting up to 5 s for the direction
        # OpenSSL asked for.  NOTE(review): there is no overall deadline, so a
        # stalled peer keeps this loop going indefinitely.
        while True:
            try:
                s.do_handshake()
                break
            except ssl.SSLWantReadError:
                select.select([s], [], [], 5.0)
            except ssl.SSLWantWriteError:
                select.select([], [s], [], 5.0)
        # SSL established
        self.assertTrue(s.getpeercert())
1987
1988    def test_connect_with_context(self):
1989        # Same as test_connect, but with a separately created context
1990        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1991        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
1992            s.connect(self.server_addr)
1993            self.assertEqual({}, s.getpeercert())
1994        # Same with a server hostname
1995        with ctx.wrap_socket(socket.socket(socket.AF_INET),
1996                            server_hostname="dummy") as s:
1997            s.connect(self.server_addr)
1998        ctx.verify_mode = ssl.CERT_REQUIRED
1999        # This should succeed because we specify the root cert
2000        ctx.load_verify_locations(SIGNING_CA)
2001        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2002            s.connect(self.server_addr)
2003            cert = s.getpeercert()
2004            self.assertTrue(cert)
2005
2006    def test_connect_with_context_fail(self):
2007        # This should fail because we have no verification certs. Connection
2008        # failure crashes ThreadedEchoServer, so run this in an independent
2009        # test method.
2010        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2011        ctx.verify_mode = ssl.CERT_REQUIRED
2012        s = ctx.wrap_socket(socket.socket(socket.AF_INET))
2013        self.addCleanup(s.close)
2014        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2015                                s.connect, self.server_addr)
2016
2017    def test_connect_capath(self):
2018        # Verify server certificates using the `capath` argument
2019        # NOTE: the subject hashing algorithm has been changed between
2020        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
2021        # contain both versions of each certificate (same content, different
2022        # filename) for this test to be portable across OpenSSL releases.
2023        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2024        ctx.verify_mode = ssl.CERT_REQUIRED
2025        ctx.load_verify_locations(capath=CAPATH)
2026        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2027            s.connect(self.server_addr)
2028            cert = s.getpeercert()
2029            self.assertTrue(cert)
2030
2031        # Same with a bytes `capath` argument
2032        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2033        ctx.verify_mode = ssl.CERT_REQUIRED
2034        ctx.load_verify_locations(capath=BYTES_CAPATH)
2035        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2036            s.connect(self.server_addr)
2037            cert = s.getpeercert()
2038            self.assertTrue(cert)
2039
2040    def test_connect_cadata(self):
2041        with open(SIGNING_CA) as f:
2042            pem = f.read()
2043        der = ssl.PEM_cert_to_DER_cert(pem)
2044        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2045        ctx.verify_mode = ssl.CERT_REQUIRED
2046        ctx.load_verify_locations(cadata=pem)
2047        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2048            s.connect(self.server_addr)
2049            cert = s.getpeercert()
2050            self.assertTrue(cert)
2051
2052        # same with DER
2053        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2054        ctx.verify_mode = ssl.CERT_REQUIRED
2055        ctx.load_verify_locations(cadata=der)
2056        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2057            s.connect(self.server_addr)
2058            cert = s.getpeercert()
2059            self.assertTrue(cert)
2060
    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
    def test_makefile_close(self):
        # Issue #5238: creating a file-like object with makefile() shouldn't
        # delay closing the underlying "real socket" (here tested with its
        # file descriptor, hence skipping the test under Windows).
        ss = test_wrap_socket(socket.socket(socket.AF_INET))
        ss.connect(self.server_addr)
        fd = ss.fileno()
        f = ss.makefile()
        f.close()
        # The fd is still open
        # (a zero-byte os.read() on it must not raise)
        os.read(fd, 0)
        # Closing the SSL socket should close the fd too
        ss.close()
        # Presumably collects any unreferenced wrapper objects so none of them
        # can still be holding the descriptor open.
        gc.collect()
        # Reading from the closed descriptor must now fail with EBADF.
        with self.assertRaises(OSError) as e:
            os.read(fd, 0)
        self.assertEqual(e.exception.errno, errno.EBADF)
2079
2080    def test_non_blocking_handshake(self):
2081        s = socket.socket(socket.AF_INET)
2082        s.connect(self.server_addr)
2083        s.setblocking(False)
2084        s = test_wrap_socket(s,
2085                            cert_reqs=ssl.CERT_NONE,
2086                            do_handshake_on_connect=False)
2087        self.addCleanup(s.close)
2088        count = 0
2089        while True:
2090            try:
2091                count += 1
2092                s.do_handshake()
2093                break
2094            except ssl.SSLWantReadError:
2095                select.select([s], [], [])
2096            except ssl.SSLWantWriteError:
2097                select.select([], [s], [])
2098        if support.verbose:
2099            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
2100
2101    def test_get_server_certificate(self):
2102        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
2103
2104    def test_get_server_certificate_fail(self):
2105        # Connection failure crashes ThreadedEchoServer, so run this in an
2106        # independent test method
2107        _test_get_server_certificate_fail(self, *self.server_addr)
2108
2109    def test_ciphers(self):
2110        with test_wrap_socket(socket.socket(socket.AF_INET),
2111                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
2112            s.connect(self.server_addr)
2113        with test_wrap_socket(socket.socket(socket.AF_INET),
2114                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
2115            s.connect(self.server_addr)
2116        # Error checking can happen at instantiation or when connecting
2117        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
2118            with socket.socket(socket.AF_INET) as sock:
2119                s = test_wrap_socket(sock,
2120                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
2121                s.connect(self.server_addr)
2122
2123    def test_get_ca_certs_capath(self):
2124        # capath certs are loaded on request
2125        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2126        ctx.load_verify_locations(capath=CAPATH)
2127        self.assertEqual(ctx.get_ca_certs(), [])
2128        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2129                             server_hostname='localhost') as s:
2130            s.connect(self.server_addr)
2131            cert = s.getpeercert()
2132            self.assertTrue(cert)
2133        self.assertEqual(len(ctx.get_ca_certs()), 1)
2134
2135    @needs_sni
2136    def test_context_setget(self):
2137        # Check that the context of a connected socket can be replaced.
2138        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2139        ctx1.load_verify_locations(capath=CAPATH)
2140        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2141        ctx2.load_verify_locations(capath=CAPATH)
2142        s = socket.socket(socket.AF_INET)
2143        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
2144            ss.connect(self.server_addr)
2145            self.assertIs(ss.context, ctx1)
2146            self.assertIs(ss._sslobj.context, ctx1)
2147            ss.context = ctx2
2148            self.assertIs(ss.context, ctx2)
2149            self.assertIs(ss._sslobj.context, ctx2)
2150
2151    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
2152        # A simple IO loop. Call func(*args) depending on the error we get
2153        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
2154        timeout = kwargs.get('timeout', 10)
2155        deadline = time.monotonic() + timeout
2156        count = 0
2157        while True:
2158            if time.monotonic() > deadline:
2159                self.fail("timeout")
2160            errno = None
2161            count += 1
2162            try:
2163                ret = func(*args)
2164            except ssl.SSLError as e:
2165                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
2166                                   ssl.SSL_ERROR_WANT_WRITE):
2167                    raise
2168                errno = e.errno
2169            # Get any data from the outgoing BIO irrespective of any error, and
2170            # send it to the socket.
2171            buf = outgoing.read()
2172            sock.sendall(buf)
2173            # If there's no error, we're done. For WANT_READ, we need to get
2174            # data from the socket and put it in the incoming BIO.
2175            if errno is None:
2176                break
2177            elif errno == ssl.SSL_ERROR_WANT_READ:
2178                buf = sock.recv(32768)
2179                if buf:
2180                    incoming.write(buf)
2181                else:
2182                    incoming.write_eof()
2183        if support.verbose:
2184            sys.stdout.write("Needed %d calls to complete %s().\n"
2185                             % (count, func.__name__))
2186        return ret
2187
2188    def test_bio_handshake(self):
2189        sock = socket.socket(socket.AF_INET)
2190        self.addCleanup(sock.close)
2191        sock.connect(self.server_addr)
2192        incoming = ssl.MemoryBIO()
2193        outgoing = ssl.MemoryBIO()
2194        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2195        self.assertTrue(ctx.check_hostname)
2196        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
2197        ctx.load_verify_locations(SIGNING_CA)
2198        sslobj = ctx.wrap_bio(incoming, outgoing, False,
2199                              SIGNED_CERTFILE_HOSTNAME)
2200        self.assertIs(sslobj._sslobj.owner, sslobj)
2201        self.assertIsNone(sslobj.cipher())
2202        self.assertIsNone(sslobj.version())
2203        self.assertIsNotNone(sslobj.shared_ciphers())
2204        self.assertRaises(ValueError, sslobj.getpeercert)
2205        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2206            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
2207        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2208        self.assertTrue(sslobj.cipher())
2209        self.assertIsNotNone(sslobj.shared_ciphers())
2210        self.assertIsNotNone(sslobj.version())
2211        self.assertTrue(sslobj.getpeercert())
2212        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2213            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
2214        try:
2215            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2216        except ssl.SSLSyscallError:
2217            # If the server shuts down the TCP connection without sending a
2218            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
2219            pass
2220        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
2221
2222    def test_bio_read_write_data(self):
2223        sock = socket.socket(socket.AF_INET)
2224        self.addCleanup(sock.close)
2225        sock.connect(self.server_addr)
2226        incoming = ssl.MemoryBIO()
2227        outgoing = ssl.MemoryBIO()
2228        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2229        ctx.verify_mode = ssl.CERT_NONE
2230        sslobj = ctx.wrap_bio(incoming, outgoing, False)
2231        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2232        req = b'FOO\n'
2233        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2234        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2235        self.assertEqual(buf, b'foo\n')
2236        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2237
2238
class NetworkedTests(unittest.TestCase):
    """Tests that talk to real hosts on the Internet."""

    def test_timeout_connect_ex(self):
        # Issue #12065: when connect_ex() times out on an SSL socket it must
        # return the raw errno, just like a plain (non-SSL) socket does.
        with support.transient_internet(REMOTE_HOST):
            sslsock = test_wrap_socket(socket.socket(socket.AF_INET),
                                       cert_reqs=ssl.CERT_REQUIRED,
                                       do_handshake_on_connect=False)
            self.addCleanup(sslsock.close)
            sslsock.settimeout(0.0000001)
            result = sslsock.connect_ex((REMOTE_HOST, 443))
            if result == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            self.assertIn(result, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(support.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        host, port = 'ipv6.google.com', 443
        with support.transient_internet(host):
            _test_get_server_certificate(self, host, port)
            _test_get_server_certificate_fail(self, host, port)
2260
2261
def _test_get_server_certificate(test, host, port, cert=None):
    """Fail *test* unless *host*:*port* hands out a PEM certificate.

    The certificate is fetched twice: first with no CA bundle
    (ca_certs=None), then with ``ca_certs=cert``.  In verbose mode the PEM
    from the second fetch is printed.
    """
    for ca_bundle in (None, cert):
        pem = ssl.get_server_certificate((host, port), ca_certs=ca_bundle)
        if not pem:
            test.fail("No server certificate on %s:%s!" % (host, port))
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n"
                         % (host, port, pem))
2272
def _test_get_server_certificate_fail(test, host, port):
    """Fail *test* if *host*:*port* verifies against CERTFILE.

    CERTFILE is this test suite's own generated certificate, so verifying a
    real server against it is expected to raise ssl.SSLError (printed in
    verbose mode).  A successful fetch is reported as a test failure.
    """
    try:
        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
    except ssl.SSLError as exc:
        # Expected outcome.
        if support.verbose:
            sys.stdout.write("%s\n" % exc)
        return
    test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
2282
2283
2284from test.ssl_servers import make_https_server
2285
class ThreadedEchoServer(threading.Thread):
    """Thread running a blocking echo server for the client/server tests.

    The listening socket is bound with support.bind_port() at construction
    time (its port is ``self.port``).  run() accepts connections and serves
    each with a ConnectionHandler thread, joining it before the next
    accept(), so connections are handled one at a time.  Ordinary data is
    echoed back lower-cased.  The command lines ``over``, ``STARTTLS``,
    ``ENDTLS``, ``CB tls-unique``, ``PHA``, ``HASCERT`` and ``GETCERT``
    trigger special handling.

    The SSLContext is either *context*, or one built from *ssl_version*
    (default PROTOCOL_TLS_SERVER), *certreqs* (default CERT_NONE),
    *cacerts*, *certificate*, *npn_protocols*, *alpn_protocols* and
    *ciphers*.

    Per-connection results are appended to ``selected_npn_protocols``,
    ``selected_alpn_protocols``, ``shared_ciphers`` and ``conn_errors``
    (errors are stored as strings).

    Use as a context manager: __enter__ starts the thread and waits until
    the socket is listening; __exit__ stops and joins it.
    """

    class ConnectionHandler(threading.Thread):

        """A mildly complicated class, because we want it to work both
        with and without the SSL wrapper around the socket connection, so
        that we can test the STARTTLS functionality."""

        def __init__(self, server, connsock, addr):
            self.server = server
            self.running = False
            self.sock = connsock
            self.addr = addr
            self.sock.setblocking(1)
            # The SSLSocket while TLS is active; None before wrap_conn()
            # succeeds, after ENDTLS, or after a clean unwrap() at EOF.
            self.sslconn = None
            threading.Thread.__init__(self)
            self.daemon = True

        def wrap_conn(self):
            """Wrap self.sock server-side using the server's context.

            Returns True on success, after recording the negotiated NPN and
            ALPN protocols and the shared ciphers on the server.  On failure
            the error is stored as a string in server.conn_errors, the
            connection is closed and False is returned.  Failures other
            than a reset/aborted connection or a broken pipe also stop the
            server.
            """
            try:
                self.sslconn = self.server.context.wrap_socket(
                    self.sock, server_side=True)
                self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError) as e:
                # We treat ConnectionResetError as though it were an
                # SSLError - OpenSSL on Ubuntu abruptly closes the
                # connection when asked to use an unsupported protocol.
                #
                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake.
                # https://github.com/openssl/openssl/issues/6342
                #
                # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake when using WinSock.
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.close()
                return False
            except (ssl.SSLError, OSError) as e:
                # OSError may occur with wrong protocols, e.g. both
                # sides use PROTOCOL_TLS_SERVER.
                #
                # XXX Various errors can have happened here, for example
                # a mismatching protocol version, an invalid certificate,
                # or a low-level bug. This should be made more discriminating.
                #
                # bpo-31323: Store the exception as string to prevent
                # a reference leak: server -> conn_errors -> exception
                # -> traceback -> self (ConnectionHandler) -> server
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.server.stop()
                self.close()
                return False
            else:
                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                    cert = self.sslconn.getpeercert()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                    cert_binary = self.sslconn.getpeercert(True)
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
                cipher = self.sslconn.cipher()
                if support.verbose and self.server.chatty:
                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                    sys.stdout.write(" server: selected protocol is now "
                            + str(self.sslconn.selected_npn_protocol()) + "\n")
                return True

        def read(self):
            """Read from the TLS layer if active, else recv() up to 1024 bytes."""
            if self.sslconn:
                return self.sslconn.read()
            else:
                return self.sock.recv(1024)

        def write(self, bytes):
            """Write *bytes* through TLS if active, else send() them raw."""
            if self.sslconn:
                return self.sslconn.write(bytes)
            else:
                return self.sock.send(bytes)

        def close(self):
            """Close the SSLSocket if there is one, otherwise the raw socket."""
            if self.sslconn:
                self.sslconn.close()
            else:
                self.sock.close()

        def run(self):
            # Serve one client.  Unless this is a STARTTLS server, TLS is set
            # up immediately; then each read() is answered until EOF,
            # "over", or an error ends the session.
            self.running = True
            if not self.server.starttls_server:
                if not self.wrap_conn():
                    return
            while self.running:
                try:
                    msg = self.read()
                    stripped = msg.strip()
                    if not stripped:
                        # eof, so quit this handler
                        self.running = False
                        # On a STARTTLS server the connection may be plain
                        # text here (before STARTTLS or after ENDTLS); there
                        # is then no TLS layer to unwrap.
                        if self.sslconn:
                            try:
                                self.sock = self.sslconn.unwrap()
                            except OSError:
                                # Many tests shut the TCP connection down
                                # without an SSL shutdown. This causes
                                # unwrap() to raise OSError with errno=0!
                                pass
                            else:
                                self.sslconn = None
                        self.close()
                    elif stripped == b'over':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: client closed connection\n")
                        self.close()
                        return
                    elif (self.server.starttls_server and
                          stripped == b'STARTTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        if not self.wrap_conn():
                            return
                    elif (self.server.starttls_server and self.sslconn
                          and stripped == b'ENDTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        self.sock = self.sslconn.unwrap()
                        self.sslconn = None
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: connection is now unencrypted...\n")
                    elif stripped == b'CB tls-unique':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                        data = self.sslconn.get_channel_binding("tls-unique")
                        self.write(repr(data).encode("us-ascii") + b"\n")
                    elif stripped == b'PHA':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: initiating post handshake auth\n")
                        try:
                            self.sslconn.verify_client_post_handshake()
                        except ssl.SSLError as e:
                            self.write(repr(e).encode("us-ascii") + b"\n")
                        else:
                            self.write(b"OK\n")
                    elif stripped == b'HASCERT':
                        if self.sslconn.getpeercert() is not None:
                            self.write(b'TRUE\n')
                        else:
                            self.write(b'FALSE\n')
                    elif stripped == b'GETCERT':
                        cert = self.sslconn.getpeercert()
                        self.write(repr(cert).encode("us-ascii") + b"\n")
                    else:
                        if (support.verbose and
                            self.server.connectionchatty):
                            ctype = (self.sslconn and "encrypted") or "unencrypted"
                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                             % (msg, ctype, msg.lower(), ctype))
                        self.write(msg.lower())
                except (ConnectionResetError, ConnectionAbortedError):
                    # XXX: OpenSSL 1.1.1 sometimes raises ConnectionResetError
                    # when connection is not shut down gracefully.
                    if self.server.chatty and support.verbose:
                        sys.stdout.write(
                            " Connection reset by peer: {}\n".format(
                                self.addr)
                        )
                    self.close()
                    self.running = False
                except ssl.SSLError as err:
                    # On Windows sometimes test_pha_required_nocert receives the
                    # PEER_DID_NOT_RETURN_A_CERTIFICATE exception
                    # before the 'tlsv13 alert certificate required' exception.
                    # If the server is stopped when PEER_DID_NOT_RETURN_A_CERTIFICATE
                    # is received test_pha_required_nocert fails with ConnectionResetError
                    # because the underlying socket is closed
                    if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' == err.reason:
                        if self.server.chatty and support.verbose:
                            sys.stdout.write(err.args[1])
                        # test_pha_required_nocert is expecting this exception
                        raise ssl.SSLError('tlsv13 alert certificate required')
                    # NOTE(review): any other SSLError is ignored and the loop
                    # reads again; if the error is persistent this may spin
                    # until the peer disconnects -- confirm this is intended.
                except OSError:
                    if self.server.chatty:
                        handle_error("Test server failure:\n")
                    self.close()
                    self.running = False

                    # normally, we'd just stop here, but for the test
                    # harness, we want to stop the server
                    self.server.stop()

    def __init__(self, certificate=None, ssl_version=None,
                 certreqs=None, cacerts=None,
                 chatty=True, connectionchatty=False, starttls_server=False,
                 npn_protocols=None, alpn_protocols=None,
                 ciphers=None, context=None):
        if context:
            self.context = context
        else:
            self.context = ssl.SSLContext(ssl_version
                                          if ssl_version is not None
                                          else ssl.PROTOCOL_TLS_SERVER)
            self.context.verify_mode = (certreqs if certreqs is not None
                                        else ssl.CERT_NONE)
            if cacerts:
                self.context.load_verify_locations(cacerts)
            if certificate:
                self.context.load_cert_chain(certificate)
            if npn_protocols:
                self.context.set_npn_protocols(npn_protocols)
            if alpn_protocols:
                self.context.set_alpn_protocols(alpn_protocols)
            if ciphers:
                self.context.set_ciphers(ciphers)
        self.chatty = chatty
        self.connectionchatty = connectionchatty
        self.starttls_server = starttls_server
        self.sock = socket.socket()
        self.port = support.bind_port(self.sock)
        # Optional Event set by run() once the socket is listening.
        self.flag = None
        # Cleared by stop(); run() exits its accept loop when False.
        self.active = False
        self.selected_npn_protocols = []
        self.selected_alpn_protocols = []
        self.shared_ciphers = []
        self.conn_errors = []
        threading.Thread.__init__(self)
        self.daemon = True

    def __enter__(self):
        # Start the thread and block until run() has called listen().
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        self.stop()
        self.join()

    def start(self, flag=None):
        """Start the thread; *flag*, if given, is set once listening."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        # A short accept() timeout lets the loop notice stop() promptly.
        self.sock.settimeout(0.05)
        self.sock.listen()
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                newconn, connaddr = self.sock.accept()
                if support.verbose and self.chatty:
                    sys.stdout.write(' server:  new connection from '
                                     + repr(connaddr) + '\n')
                handler = self.ConnectionHandler(self, newconn, connaddr)
                handler.start()
                # Connections are served one at a time.
                handler.join()
            except socket.timeout:
                pass
            except KeyboardInterrupt:
                self.stop()
            except BaseException as e:
                if support.verbose and self.chatty:
                    sys.stdout.write(
                        ' connection handling failed: ' + repr(e) + '\n')

        self.sock.close()

    def stop(self):
        """Make run() leave its accept loop after the current iteration."""
        self.active = False
class AsyncoreEchoServer(threading.Thread):

    # this one's based on asyncore.dispatcher
    #
    # Echo server whose asyncore loop runs in this thread.  Each accepted
    # connection is wrapped server-side with test_wrap_socket() using a
    # non-blocking, incrementally driven handshake; data received afterwards
    # is echoed back lower-cased.  Use as a context manager.

    class EchoServer (asyncore.dispatcher):
        # Listening dispatcher: binds with support.bind_port() and creates a
        # ConnectionHandler for every accepted connection.

        class ConnectionHandler(asyncore.dispatcher_with_send):
            # Per-connection dispatcher: finishes the TLS handshake step by
            # step from asyncore callbacks, then echoes data lower-cased.

            def __init__(self, conn, certfile):
                # do_handshake_on_connect=False: the handshake is advanced
                # by _do_ssl_handshake() instead of blocking here.
                self.socket = test_wrap_socket(conn, server_side=True,
                                              certfile=certfile,
                                              do_handshake_on_connect=False)
                asyncore.dispatcher_with_send.__init__(self, self.socket)
                # True until the handshake completes; selects the behaviour
                # of handle_read().
                self._ssl_accepting = True
                self._do_ssl_handshake()

            def readable(self):
                # Bytes already decrypted and buffered by the SSL layer
                # (pending()) are not visible to select(), so process them
                # here before asyncore polls the socket.
                if isinstance(self.socket, ssl.SSLSocket):
                    while self.socket.pending() > 0:
                        self.handle_read_event()
                return True

            def _do_ssl_handshake(self):
                # Advance the non-blocking handshake by one step.
                # WANT_READ/WANT_WRITE: try again on a later event.
                # SSLEOFError or ECONNABORTED: close the connection.
                # Other SSLErrors propagate (asyncore then calls
                # handle_error(), which re-raises).
                # On success, _ssl_accepting is cleared.
                # NOTE(review): any other OSError is silently ignored and the
                # handshake is retried on the next read event.
                try:
                    self.socket.do_handshake()
                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                    return
                except ssl.SSLEOFError:
                    return self.handle_close()
                except ssl.SSLError:
                    raise
                except OSError as err:
                    if err.args[0] == errno.ECONNABORTED:
                        return self.handle_close()
                else:
                    self._ssl_accepting = False

            def handle_read(self):
                # While handshaking, a readable socket means "continue the
                # handshake"; afterwards, echo up to 1024 bytes lower-cased
                # and close on EOF (empty read).
                if self._ssl_accepting:
                    self._do_ssl_handshake()
                else:
                    data = self.recv(1024)
                    if support.verbose:
                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
                    if not data:
                        self.close()
                    else:
                        self.send(data.lower())

            def handle_close(self):
                self.close()
                if support.verbose:
                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)

            def handle_error(self):
                # Re-raise the exception currently being handled instead of
                # asyncore's default of logging it and closing the socket.
                raise

        def __init__(self, certfile):
            # certfile is passed on to every ConnectionHandler.
            self.certfile = certfile
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = support.bind_port(sock, '')
            asyncore.dispatcher.__init__(self, sock)
            self.listen(5)

        def handle_accepted(self, sock_obj, addr):
            if support.verbose:
                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
            # The handler registers itself with asyncore's socket map in
            # its dispatcher __init__, so no reference is kept here.
            self.ConnectionHandler(sock_obj, self.certfile)

        def handle_error(self):
            # Re-raise rather than asyncore's default log-and-close.
            raise

    def __init__(self, certfile):
        # Event set by run() once the loop is about to start (see start()).
        self.flag = None
        # Cleared by stop(); run() keeps re-entering asyncore.loop() while
        # this is True.
        self.active = False
        self.server = self.EchoServer(certfile)
        self.port = self.server.port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def __enter__(self):
        # Start the thread and wait until run() has signalled the event.
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        if support.verbose:
            sys.stdout.write(" cleanup: stopping server.\n")
        self.stop()
        if support.verbose:
            sys.stdout.write(" cleanup: joining server thread.\n")
        self.join()
        if support.verbose:
            sys.stdout.write(" cleanup: successfully joined.\n")
        # make sure that ConnectionHandler is removed from socket_map
        asyncore.close_all(ignore_all=True)

    def start (self, flag=None):
        # Like Thread.start(), but stores *flag* for run() to set.
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        self.active = True
        if self.flag:
            self.flag.set()
        while self.active:
            # asyncore.loop(1) polls with a 1-second timeout until the
            # socket map is empty; the outer loop re-enters it until stop().
            # NOTE(review): the bare except swallows everything raised from
            # the loop, including errors re-raised by handle_error().
            try:
                asyncore.loop(1)
            except:
                pass

    def stop(self):
        # Clear the run flag and close the listening dispatcher.
        self.active = False
        self.server.close()
2680
def server_params_test(client_context, server_context, indata=b"FOO\n",
                       chatty=True, connectionchatty=False, sni_name=None,
                       session=None):
    """Do an echo round-trip between *client_context* and *server_context*.

    A ThreadedEchoServer is started with *server_context*.  A client socket
    wrapped with *client_context* (optionally with *sni_name* and
    *session*) sends *indata* as bytes, bytearray and memoryview, and checks
    that each reply equals ``indata.lower()``.  It then sends ``over``.

    *connectionchatty* only affects client-side verbose output; the server
    is always created with connectionchatty=False.

    Returns a dict with the client's negotiated parameters (compression,
    cipher, peercert, client ALPN/NPN protocol, version, session_reused,
    session) plus the server's selected ALPN/NPN protocols and shared
    ciphers.  Raises AssertionError on a wrong reply; connection and
    handshake errors propagate.
    """
    def _log(fmt, *fmt_args):
        if connectionchatty and support.verbose:
            sys.stdout.write(fmt % fmt_args)

    stats = {}
    server = ThreadedEchoServer(context=server_context,
                                chatty=chatty,
                                connectionchatty=False)
    with server:
        with client_context.wrap_socket(socket.socket(),
                                        server_hostname=sni_name,
                                        session=session) as conn:
            conn.connect((HOST, server.port))
            # write() must accept every bytes-like type.
            payloads = (indata, bytearray(indata), memoryview(indata))
            for payload in payloads:
                _log(" client:  sending %r...\n", indata)
                conn.write(payload)
                reply = conn.read()
                _log(" client:  read %r\n", reply)
                if reply != indata.lower():
                    raise AssertionError(
                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
                        % (reply[:20], len(reply),
                           indata[:20].lower(), len(indata)))
            conn.write(b"over\n")
            _log(" client:  closing connection.\n")
            stats.update(
                compression=conn.compression(),
                cipher=conn.cipher(),
                peercert=conn.getpeercert(),
                client_alpn_protocol=conn.selected_alpn_protocol(),
                client_npn_protocol=conn.selected_npn_protocol(),
                version=conn.version(),
                session_reused=conn.session_reused,
                session=conn.session,
            )
            conn.close()
        stats['server_alpn_protocols'] = server.selected_alpn_protocols
        stats['server_npn_protocols'] = server.selected_npn_protocols
        stats['server_shared_ciphers'] = server.shared_ciphers
    return stats
2730
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """Connect a *client_protocol* client to a *server_protocol* server.

    When *expect_success* is falsy the connection must fail, with an
    SSLError or ECONNRESET.  When it is truthy the connection must succeed.
    If it is a string (rather than True), it must also equal the negotiated
    protocol version.

    *certsreqs* (default CERT_NONE) is applied as verify_mode on both ends.
    *server_options* and *client_options* are OR-ed into the respective
    context options.
    """
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    mode_names = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }
    certtype = mode_names[certsreqs]
    if support.verbose:
        template = " %s->%s %s\n" if expect_success else " {%s->%s} %s\n"
        sys.stdout.write(template %
                         (ssl.get_protocol_name(client_protocol),
                          ssl.get_protocol_name(server_protocol),
                          certtype))

    client_context = ssl.SSLContext(client_protocol)
    client_context.options |= client_options
    server_context = ssl.SSLContext(server_protocol)
    server_context.options |= server_options

    # If the OpenSSL configuration enforces a newer minimum TLS version, lower
    # the server's minimum so old client protocols can still be tested.
    # SSLContext.minimum_version is only available on recent OpenSSL (setter
    # added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1).
    floor = PROTOCOL_TO_TLS_VERSION.get(client_protocol)
    must_lower_floor = (floor is not None
                        and hasattr(server_context, 'minimum_version')
                        and server_protocol == ssl.PROTOCOL_TLS
                        and server_context.minimum_version > floor)
    if must_lower_floor:
        server_context.minimum_version = floor

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    for ctx in (client_context, server_context):
        ctx.verify_mode = certsreqs
        ctx.load_cert_chain(SIGNED_CERTFILE)
        ctx.load_verify_locations(SIGNING_CA)

    # A protocol mismatch surfaces either as an SSLError or as a
    # "Connection reset by peer" OSError.
    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    except ssl.SSLError:
        if expect_success:
            raise
        return
    except OSError as exc:
        if expect_success or exc.errno != errno.ECONNRESET:
            raise
        return

    if not expect_success:
        raise AssertionError(
            "Client protocol %s succeeded with server protocol %s!"
            % (ssl.get_protocol_name(client_protocol),
               ssl.get_protocol_name(server_protocol)))
    if expect_success is not True and expect_success != stats['version']:
        raise AssertionError("version mismatch: expected %r, got %r"
                             % (expect_success, stats['version']))
2800
2801
2802class ThreadedTests(unittest.TestCase):
2803
2804    def test_echo(self):
2805        """Basic test of an SSL client connecting to a server"""
2806        if support.verbose:
2807            sys.stdout.write("\n")
2808        for protocol in PROTOCOLS:
2809            if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
2810                continue
2811            if not has_tls_protocol(protocol):
2812                continue
2813            with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
2814                context = ssl.SSLContext(protocol)
2815                context.load_cert_chain(CERTFILE)
2816                server_params_test(context, context,
2817                                   chatty=True, connectionchatty=True)
2818
2819        client_context, server_context, hostname = testing_context()
2820
2821        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2822            server_params_test(client_context=client_context,
2823                               server_context=server_context,
2824                               chatty=True, connectionchatty=True,
2825                               sni_name=hostname)
2826
2827        client_context.check_hostname = False
2828        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2829            with self.assertRaises(ssl.SSLError) as e:
2830                server_params_test(client_context=server_context,
2831                                   server_context=client_context,
2832                                   chatty=True, connectionchatty=True,
2833                                   sni_name=hostname)
2834            self.assertIn('called a function you should not call',
2835                          str(e.exception))
2836
2837        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2838            with self.assertRaises(ssl.SSLError) as e:
2839                server_params_test(client_context=server_context,
2840                                   server_context=server_context,
2841                                   chatty=True, connectionchatty=True)
2842            self.assertIn('called a function you should not call',
2843                          str(e.exception))
2844
2845        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2846            with self.assertRaises(ssl.SSLError) as e:
2847                server_params_test(client_context=server_context,
2848                                   server_context=client_context,
2849                                   chatty=True, connectionchatty=True)
2850            self.assertIn('called a function you should not call',
2851                          str(e.exception))
2852
2853    def test_getpeercert(self):
2854        if support.verbose:
2855            sys.stdout.write("\n")
2856
2857        client_context, server_context, hostname = testing_context()
2858        server = ThreadedEchoServer(context=server_context, chatty=False)
2859        with server:
2860            with client_context.wrap_socket(socket.socket(),
2861                                            do_handshake_on_connect=False,
2862                                            server_hostname=hostname) as s:
2863                s.connect((HOST, server.port))
2864                # getpeercert() raise ValueError while the handshake isn't
2865                # done.
2866                with self.assertRaises(ValueError):
2867                    s.getpeercert()
2868                s.do_handshake()
2869                cert = s.getpeercert()
2870                self.assertTrue(cert, "Can't get peer certificate.")
2871                cipher = s.cipher()
2872                if support.verbose:
2873                    sys.stdout.write(pprint.pformat(cert) + '\n')
2874                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2875                if 'subject' not in cert:
2876                    self.fail("No subject field in certificate: %s." %
2877                              pprint.pformat(cert))
2878                if ((('organizationName', 'Python Software Foundation'),)
2879                    not in cert['subject']):
2880                    self.fail(
2881                        "Missing or invalid 'organizationName' field in certificate subject; "
2882                        "should be 'Python Software Foundation'.")
2883                self.assertIn('notBefore', cert)
2884                self.assertIn('notAfter', cert)
2885                before = ssl.cert_time_to_seconds(cert['notBefore'])
2886                after = ssl.cert_time_to_seconds(cert['notAfter'])
2887                self.assertLess(before, after)
2888
2889    @unittest.skipUnless(have_verify_flags(),
2890                        "verify_flags need OpenSSL > 0.9.8")
2891    def test_crl_check(self):
2892        if support.verbose:
2893            sys.stdout.write("\n")
2894
2895        client_context, server_context, hostname = testing_context()
2896
2897        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
2898        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
2899
2900        # VERIFY_DEFAULT should pass
2901        server = ThreadedEchoServer(context=server_context, chatty=True)
2902        with server:
2903            with client_context.wrap_socket(socket.socket(),
2904                                            server_hostname=hostname) as s:
2905                s.connect((HOST, server.port))
2906                cert = s.getpeercert()
2907                self.assertTrue(cert, "Can't get peer certificate.")
2908
2909        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
2910        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
2911
2912        server = ThreadedEchoServer(context=server_context, chatty=True)
2913        with server:
2914            with client_context.wrap_socket(socket.socket(),
2915                                            server_hostname=hostname) as s:
2916                with self.assertRaisesRegex(ssl.SSLError,
2917                                            "certificate verify failed"):
2918                    s.connect((HOST, server.port))
2919
2920        # now load a CRL file. The CRL file is signed by the CA.
2921        client_context.load_verify_locations(CRLFILE)
2922
2923        server = ThreadedEchoServer(context=server_context, chatty=True)
2924        with server:
2925            with client_context.wrap_socket(socket.socket(),
2926                                            server_hostname=hostname) as s:
2927                s.connect((HOST, server.port))
2928                cert = s.getpeercert()
2929                self.assertTrue(cert, "Can't get peer certificate.")
2930
2931    def test_check_hostname(self):
2932        if support.verbose:
2933            sys.stdout.write("\n")
2934
2935        client_context, server_context, hostname = testing_context()
2936
2937        # correct hostname should verify
2938        server = ThreadedEchoServer(context=server_context, chatty=True)
2939        with server:
2940            with client_context.wrap_socket(socket.socket(),
2941                                            server_hostname=hostname) as s:
2942                s.connect((HOST, server.port))
2943                cert = s.getpeercert()
2944                self.assertTrue(cert, "Can't get peer certificate.")
2945
2946        # incorrect hostname should raise an exception
2947        server = ThreadedEchoServer(context=server_context, chatty=True)
2948        with server:
2949            with client_context.wrap_socket(socket.socket(),
2950                                            server_hostname="invalid") as s:
2951                with self.assertRaisesRegex(
2952                        ssl.CertificateError,
2953                        "Hostname mismatch, certificate is not valid for 'invalid'."):
2954                    s.connect((HOST, server.port))
2955
2956        # missing server_hostname arg should cause an exception, too
2957        server = ThreadedEchoServer(context=server_context, chatty=True)
2958        with server:
2959            with socket.socket() as s:
2960                with self.assertRaisesRegex(ValueError,
2961                                            "check_hostname requires server_hostname"):
2962                    client_context.wrap_socket(s)
2963
2964    def test_ecc_cert(self):
2965        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2966        client_context.load_verify_locations(SIGNING_CA)
2967        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
2968        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
2969
2970        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2971        # load ECC cert
2972        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
2973
2974        # correct hostname should verify
2975        server = ThreadedEchoServer(context=server_context, chatty=True)
2976        with server:
2977            with client_context.wrap_socket(socket.socket(),
2978                                            server_hostname=hostname) as s:
2979                s.connect((HOST, server.port))
2980                cert = s.getpeercert()
2981                self.assertTrue(cert, "Can't get peer certificate.")
2982                cipher = s.cipher()[0].split('-')
2983                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
2984
2985    def test_dual_rsa_ecc(self):
2986        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2987        client_context.load_verify_locations(SIGNING_CA)
2988        # TODO: fix TLSv1.3 once SSLContext can restrict signature
2989        #       algorithms.
2990        client_context.options |= ssl.OP_NO_TLSv1_3
2991        # only ECDSA certs
2992        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
2993        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
2994
2995        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
2996        # load ECC and RSA key/cert pairs
2997        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
2998        server_context.load_cert_chain(SIGNED_CERTFILE)
2999
3000        # correct hostname should verify
3001        server = ThreadedEchoServer(context=server_context, chatty=True)
3002        with server:
3003            with client_context.wrap_socket(socket.socket(),
3004                                            server_hostname=hostname) as s:
3005                s.connect((HOST, server.port))
3006                cert = s.getpeercert()
3007                self.assertTrue(cert, "Can't get peer certificate.")
3008                cipher = s.cipher()[0].split('-')
3009                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3010
3011    def test_check_hostname_idn(self):
3012        if support.verbose:
3013            sys.stdout.write("\n")
3014
3015        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3016        server_context.load_cert_chain(IDNSANSFILE)
3017
3018        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3019        context.verify_mode = ssl.CERT_REQUIRED
3020        context.check_hostname = True
3021        context.load_verify_locations(SIGNING_CA)
3022
3023        # correct hostname should verify, when specified in several
3024        # different ways
3025        idn_hostnames = [
3026            ('könig.idn.pythontest.net',
3027             'xn--knig-5qa.idn.pythontest.net'),
3028            ('xn--knig-5qa.idn.pythontest.net',
3029             'xn--knig-5qa.idn.pythontest.net'),
3030            (b'xn--knig-5qa.idn.pythontest.net',
3031             'xn--knig-5qa.idn.pythontest.net'),
3032
3033            ('königsgäßchen.idna2003.pythontest.net',
3034             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3035            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3036             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3037            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3038             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3039
3040            # ('königsgäßchen.idna2008.pythontest.net',
3041            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3042            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3043             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3044            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3045             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3046
3047        ]
3048        for server_hostname, expected_hostname in idn_hostnames:
3049            server = ThreadedEchoServer(context=server_context, chatty=True)
3050            with server:
3051                with context.wrap_socket(socket.socket(),
3052                                         server_hostname=server_hostname) as s:
3053                    self.assertEqual(s.server_hostname, expected_hostname)
3054                    s.connect((HOST, server.port))
3055                    cert = s.getpeercert()
3056                    self.assertEqual(s.server_hostname, expected_hostname)
3057                    self.assertTrue(cert, "Can't get peer certificate.")
3058
3059        # incorrect hostname should raise an exception
3060        server = ThreadedEchoServer(context=server_context, chatty=True)
3061        with server:
3062            with context.wrap_socket(socket.socket(),
3063                                     server_hostname="python.example.org") as s:
3064                with self.assertRaises(ssl.CertificateError):
3065                    s.connect((HOST, server.port))
3066
3067    def test_wrong_cert_tls12(self):
3068        """Connecting when the server rejects the client's certificate
3069
3070        Launch a server with CERT_REQUIRED, and check that trying to
3071        connect to it with a wrong client certificate fails.
3072        """
3073        client_context, server_context, hostname = testing_context()
3074        # load client cert that is not signed by trusted CA
3075        client_context.load_cert_chain(CERTFILE)
3076        # require TLS client authentication
3077        server_context.verify_mode = ssl.CERT_REQUIRED
3078        # TLS 1.3 has different handshake
3079        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3080
3081        server = ThreadedEchoServer(
3082            context=server_context, chatty=True, connectionchatty=True,
3083        )
3084
3085        with server, \
3086                client_context.wrap_socket(socket.socket(),
3087                                           server_hostname=hostname) as s:
3088            try:
3089                # Expect either an SSL error about the server rejecting
3090                # the connection, or a low-level connection reset (which
3091                # sometimes happens on Windows)
3092                s.connect((HOST, server.port))
3093            except ssl.SSLError as e:
3094                if support.verbose:
3095                    sys.stdout.write("\nSSLError is %r\n" % e)
3096            except OSError as e:
3097                if e.errno != errno.ECONNRESET:
3098                    raise
3099                if support.verbose:
3100                    sys.stdout.write("\nsocket.error is %r\n" % e)
3101            else:
3102                self.fail("Use of invalid cert should have failed!")
3103
3104    @requires_tls_version('TLSv1_3')
3105    def test_wrong_cert_tls13(self):
3106        client_context, server_context, hostname = testing_context()
3107        # load client cert that is not signed by trusted CA
3108        client_context.load_cert_chain(CERTFILE)
3109        server_context.verify_mode = ssl.CERT_REQUIRED
3110        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3111        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3112
3113        server = ThreadedEchoServer(
3114            context=server_context, chatty=True, connectionchatty=True,
3115        )
3116        with server, \
3117             client_context.wrap_socket(socket.socket(),
3118                                        server_hostname=hostname) as s:
3119            # TLS 1.3 perform client cert exchange after handshake
3120            s.connect((HOST, server.port))
3121            try:
3122                s.write(b'data')
3123                s.read(4)
3124            except ssl.SSLError as e:
3125                if support.verbose:
3126                    sys.stdout.write("\nSSLError is %r\n" % e)
3127            except OSError as e:
3128                if e.errno != errno.ECONNRESET:
3129                    raise
3130                if support.verbose:
3131                    sys.stdout.write("\nsocket.error is %r\n" % e)
3132            else:
3133                self.fail("Use of invalid cert should have failed!")
3134
3135    def test_rude_shutdown(self):
3136        """A brutal shutdown of an SSL server should raise an OSError
3137        in the client when attempting handshake.
3138        """
3139        listener_ready = threading.Event()
3140        listener_gone = threading.Event()
3141
3142        s = socket.socket()
3143        port = support.bind_port(s, HOST)
3144
3145        # `listener` runs in a thread.  It sits in an accept() until
3146        # the main thread connects.  Then it rudely closes the socket,
3147        # and sets Event `listener_gone` to let the main thread know
3148        # the socket is gone.
3149        def listener():
3150            s.listen()
3151            listener_ready.set()
3152            newsock, addr = s.accept()
3153            newsock.close()
3154            s.close()
3155            listener_gone.set()
3156
3157        def connector():
3158            listener_ready.wait()
3159            with socket.socket() as c:
3160                c.connect((HOST, port))
3161                listener_gone.wait()
3162                try:
3163                    ssl_sock = test_wrap_socket(c)
3164                except OSError:
3165                    pass
3166                else:
3167                    self.fail('connecting to closed SSL socket should have failed')
3168
3169        t = threading.Thread(target=listener)
3170        t.start()
3171        try:
3172            connector()
3173        finally:
3174            t.join()
3175
3176    def test_ssl_cert_verify_error(self):
3177        if support.verbose:
3178            sys.stdout.write("\n")
3179
3180        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3181        server_context.load_cert_chain(SIGNED_CERTFILE)
3182
3183        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3184
3185        server = ThreadedEchoServer(context=server_context, chatty=True)
3186        with server:
3187            with context.wrap_socket(socket.socket(),
3188                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3189                try:
3190                    s.connect((HOST, server.port))
3191                except ssl.SSLError as e:
3192                    msg = 'unable to get local issuer certificate'
3193                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3194                    self.assertEqual(e.verify_code, 20)
3195                    self.assertEqual(e.verify_message, msg)
3196                    self.assertIn(msg, repr(e))
3197                    self.assertIn('certificate verify failed', repr(e))
3198
3199    @requires_tls_version('SSLv2')
3200    def test_protocol_sslv2(self):
3201        """Connecting to an SSLv2 server with various client options"""
3202        if support.verbose:
3203            sys.stdout.write("\n")
3204        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3205        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3206        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3207        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3208        if has_tls_version('SSLv3'):
3209            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3210        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3211        # SSLv23 client with specific SSL options
3212        if no_sslv2_implies_sslv3_hello():
3213            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3214            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3215                               client_options=ssl.OP_NO_SSLv2)
3216        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3217                           client_options=ssl.OP_NO_SSLv3)
3218        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3219                           client_options=ssl.OP_NO_TLSv1)
3220
3221    def test_PROTOCOL_TLS(self):
3222        """Connecting to an SSLv23 server with various client options"""
3223        if support.verbose:
3224            sys.stdout.write("\n")
3225        if has_tls_version('SSLv2'):
3226            try:
3227                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3228            except OSError as x:
3229                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3230                if support.verbose:
3231                    sys.stdout.write(
3232                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3233                        % str(x))
3234        if has_tls_version('SSLv3'):
3235            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3236        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3237        if has_tls_version('TLSv1'):
3238            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3239
3240        if has_tls_version('SSLv3'):
3241            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3242        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3243        if has_tls_version('TLSv1'):
3244            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3245
3246        if has_tls_version('SSLv3'):
3247            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3248        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3249        if has_tls_version('TLSv1'):
3250            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3251
3252        # Server with specific SSL options
3253        if has_tls_version('SSLv3'):
3254            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3255                           server_options=ssl.OP_NO_SSLv3)
3256        # Will choose TLSv1
3257        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3258                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3259        if has_tls_version('TLSv1'):
3260            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3261                               server_options=ssl.OP_NO_TLSv1)
3262
3263    @requires_tls_version('SSLv3')
3264    def test_protocol_sslv3(self):
3265        """Connecting to an SSLv3 server with various client options"""
3266        if support.verbose:
3267            sys.stdout.write("\n")
3268        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3269        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3270        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3271        if has_tls_version('SSLv2'):
3272            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3273        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3274                           client_options=ssl.OP_NO_SSLv3)
3275        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3276        if no_sslv2_implies_sslv3_hello():
3277            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3278            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS,
3279                               False, client_options=ssl.OP_NO_SSLv2)
3280
3281    @requires_tls_version('TLSv1')
3282    def test_protocol_tlsv1(self):
3283        """Connecting to a TLSv1 server with various client options"""
3284        if support.verbose:
3285            sys.stdout.write("\n")
3286        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3287        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3288        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3289        if has_tls_version('SSLv2'):
3290            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3291        if has_tls_version('SSLv3'):
3292            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3293        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3294                           client_options=ssl.OP_NO_TLSv1)
3295
3296    @requires_tls_version('TLSv1_1')
3297    def test_protocol_tlsv1_1(self):
3298        """Connecting to a TLSv1.1 server with various client options.
3299           Testing against older TLS versions."""
3300        if support.verbose:
3301            sys.stdout.write("\n")
3302        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3303        if has_tls_version('SSLv2'):
3304            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3305        if has_tls_version('SSLv3'):
3306            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3307        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3308                           client_options=ssl.OP_NO_TLSv1_1)
3309
3310        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3311        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3312        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3313
3314    @requires_tls_version('TLSv1_2')
3315    def test_protocol_tlsv1_2(self):
3316        """Connecting to a TLSv1.2 server with various client options.
3317           Testing against older TLS versions."""
3318        if support.verbose:
3319            sys.stdout.write("\n")
3320        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3321                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3322                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3323        if has_tls_version('SSLv2'):
3324            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3325        if has_tls_version('SSLv3'):
3326            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3327        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3328                           client_options=ssl.OP_NO_TLSv1_2)
3329
3330        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3331        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3332        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3333        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3334        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3335
3336    def test_starttls(self):
3337        """Switching from clear text to encrypted and back again."""
3338        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
3339
3340        server = ThreadedEchoServer(CERTFILE,
3341                                    starttls_server=True,
3342                                    chatty=True,
3343                                    connectionchatty=True)
3344        wrapped = False
3345        with server:
3346            s = socket.socket()
3347            s.setblocking(1)
3348            s.connect((HOST, server.port))
3349            if support.verbose:
3350                sys.stdout.write("\n")
3351            for indata in msgs:
3352                if support.verbose:
3353                    sys.stdout.write(
3354                        " client:  sending %r...\n" % indata)
3355                if wrapped:
3356                    conn.write(indata)
3357                    outdata = conn.read()
3358                else:
3359                    s.send(indata)
3360                    outdata = s.recv(1024)
3361                msg = outdata.strip().lower()
3362                if indata == b"STARTTLS" and msg.startswith(b"ok"):
3363                    # STARTTLS ok, switch to secure mode
3364                    if support.verbose:
3365                        sys.stdout.write(
3366                            " client:  read %r from server, starting TLS...\n"
3367                            % msg)
3368                    conn = test_wrap_socket(s)
3369                    wrapped = True
3370                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
3371                    # ENDTLS ok, switch back to clear text
3372                    if support.verbose:
3373                        sys.stdout.write(
3374                            " client:  read %r from server, ending TLS...\n"
3375                            % msg)
3376                    s = conn.unwrap()
3377                    wrapped = False
3378                else:
3379                    if support.verbose:
3380                        sys.stdout.write(
3381                            " client:  read %r from server\n" % msg)
3382            if support.verbose:
3383                sys.stdout.write(" client:  closing connection.\n")
3384            if wrapped:
3385                conn.write(b"over\n")
3386            else:
3387                s.send(b"over\n")
3388            if wrapped:
3389                conn.close()
3390            else:
3391                s.close()
3392
3393    def test_socketserver(self):
3394        """Using socketserver to create and manage SSL connections."""
3395        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3396        # try to connect
3397        if support.verbose:
3398            sys.stdout.write('\n')
3399        with open(CERTFILE, 'rb') as f:
3400            d1 = f.read()
3401        d2 = ''
3402        # now fetch the same data from the HTTPS server
3403        url = 'https://localhost:%d/%s' % (
3404            server.port, os.path.split(CERTFILE)[1])
3405        context = ssl.create_default_context(cafile=SIGNING_CA)
3406        f = urllib.request.urlopen(url, context=context)
3407        try:
3408            dlen = f.info().get("content-length")
3409            if dlen and (int(dlen) > 0):
3410                d2 = f.read(int(dlen))
3411                if support.verbose:
3412                    sys.stdout.write(
3413                        " client: read %d bytes from remote server '%s'\n"
3414                        % (len(d2), server))
3415        finally:
3416            f.close()
3417        self.assertEqual(d1, d2)
3418
3419    def test_asyncore_server(self):
3420        """Check the example asyncore integration."""
3421        if support.verbose:
3422            sys.stdout.write("\n")
3423
3424        indata = b"FOO\n"
3425        server = AsyncoreEchoServer(CERTFILE)
3426        with server:
3427            s = test_wrap_socket(socket.socket())
3428            s.connect(('127.0.0.1', server.port))
3429            if support.verbose:
3430                sys.stdout.write(
3431                    " client:  sending %r...\n" % indata)
3432            s.write(indata)
3433            outdata = s.read()
3434            if support.verbose:
3435                sys.stdout.write(" client:  read %r\n" % outdata)
3436            if outdata != indata.lower():
3437                self.fail(
3438                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3439                    % (outdata[:20], len(outdata),
3440                       indata[:20].lower(), len(indata)))
3441            s.write(b"over\n")
3442            if support.verbose:
3443                sys.stdout.write(" client:  closing connection.\n")
3444            s.close()
3445            if support.verbose:
3446                sys.stdout.write(" client:  connection closed.\n")
3447
3448    def test_recv_send(self):
3449        """Test recv(), send() and friends."""
3450        if support.verbose:
3451            sys.stdout.write("\n")
3452
3453        server = ThreadedEchoServer(CERTFILE,
3454                                    certreqs=ssl.CERT_NONE,
3455                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3456                                    cacerts=CERTFILE,
3457                                    chatty=True,
3458                                    connectionchatty=False)
3459        with server:
3460            s = test_wrap_socket(socket.socket(),
3461                                server_side=False,
3462                                certfile=CERTFILE,
3463                                ca_certs=CERTFILE,
3464                                cert_reqs=ssl.CERT_NONE,
3465                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3466            s.connect((HOST, server.port))
3467            # helper methods for standardising recv* method signatures
3468            def _recv_into():
3469                b = bytearray(b"\0"*100)
3470                count = s.recv_into(b)
3471                return b[:count]
3472
3473            def _recvfrom_into():
3474                b = bytearray(b"\0"*100)
3475                count, addr = s.recvfrom_into(b)
3476                return b[:count]
3477
3478            # (name, method, expect success?, *args, return value func)
3479            send_methods = [
3480                ('send', s.send, True, [], len),
3481                ('sendto', s.sendto, False, ["some.address"], len),
3482                ('sendall', s.sendall, True, [], lambda x: None),
3483            ]
3484            # (name, method, whether to expect success, *args)
3485            recv_methods = [
3486                ('recv', s.recv, True, []),
3487                ('recvfrom', s.recvfrom, False, ["some.address"]),
3488                ('recv_into', _recv_into, True, []),
3489                ('recvfrom_into', _recvfrom_into, False, []),
3490            ]
3491            data_prefix = "PREFIX_"
3492
3493            for (meth_name, send_meth, expect_success, args,
3494                    ret_val_meth) in send_methods:
3495                indata = (data_prefix + meth_name).encode('ascii')
3496                try:
3497                    ret = send_meth(indata, *args)
3498                    msg = "sending with {}".format(meth_name)
3499                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3500                    outdata = s.read()
3501                    if outdata != indata.lower():
3502                        self.fail(
3503                            "While sending with <<{name:s}>> bad data "
3504                            "<<{outdata:r}>> ({nout:d}) received; "
3505                            "expected <<{indata:r}>> ({nin:d})\n".format(
3506                                name=meth_name, outdata=outdata[:20],
3507                                nout=len(outdata),
3508                                indata=indata[:20], nin=len(indata)
3509                            )
3510                        )
3511                except ValueError as e:
3512                    if expect_success:
3513                        self.fail(
3514                            "Failed to send with method <<{name:s}>>; "
3515                            "expected to succeed.\n".format(name=meth_name)
3516                        )
3517                    if not str(e).startswith(meth_name):
3518                        self.fail(
3519                            "Method <<{name:s}>> failed with unexpected "
3520                            "exception message: {exp:s}\n".format(
3521                                name=meth_name, exp=e
3522                            )
3523                        )
3524
3525            for meth_name, recv_meth, expect_success, args in recv_methods:
3526                indata = (data_prefix + meth_name).encode('ascii')
3527                try:
3528                    s.send(indata)
3529                    outdata = recv_meth(*args)
3530                    if outdata != indata.lower():
3531                        self.fail(
3532                            "While receiving with <<{name:s}>> bad data "
3533                            "<<{outdata:r}>> ({nout:d}) received; "
3534                            "expected <<{indata:r}>> ({nin:d})\n".format(
3535                                name=meth_name, outdata=outdata[:20],
3536                                nout=len(outdata),
3537                                indata=indata[:20], nin=len(indata)
3538                            )
3539                        )
3540                except ValueError as e:
3541                    if expect_success:
3542                        self.fail(
3543                            "Failed to receive with method <<{name:s}>>; "
3544                            "expected to succeed.\n".format(name=meth_name)
3545                        )
3546                    if not str(e).startswith(meth_name):
3547                        self.fail(
3548                            "Method <<{name:s}>> failed with unexpected "
3549                            "exception message: {exp:s}\n".format(
3550                                name=meth_name, exp=e
3551                            )
3552                        )
3553                    # consume data
3554                    s.read()
3555
3556            # read(-1, buffer) is supported, even though read(-1) is not
3557            data = b"data"
3558            s.send(data)
3559            buffer = bytearray(len(data))
3560            self.assertEqual(s.read(-1, buffer), len(data))
3561            self.assertEqual(buffer, data)
3562
3563            # sendall accepts bytes-like objects
3564            if ctypes is not None:
3565                ubyte = ctypes.c_ubyte * len(data)
3566                byteslike = ubyte.from_buffer_copy(data)
3567                s.sendall(byteslike)
3568                self.assertEqual(s.read(), data)
3569
3570            # Make sure sendmsg et al are disallowed to avoid
3571            # inadvertent disclosure of data and/or corruption
3572            # of the encrypted data stream
3573            self.assertRaises(NotImplementedError, s.dup)
3574            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3575            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3576            self.assertRaises(NotImplementedError,
3577                              s.recvmsg_into, [bytearray(100)])
3578            s.write(b"over\n")
3579
3580            self.assertRaises(ValueError, s.recv, -1)
3581            self.assertRaises(ValueError, s.read, -1)
3582
3583            s.close()
3584
3585    def test_recv_zero(self):
3586        server = ThreadedEchoServer(CERTFILE)
3587        server.__enter__()
3588        self.addCleanup(server.__exit__, None, None)
3589        s = socket.create_connection((HOST, server.port))
3590        self.addCleanup(s.close)
3591        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3592        self.addCleanup(s.close)
3593
3594        # recv/read(0) should return no data
3595        s.send(b"data")
3596        self.assertEqual(s.recv(0), b"")
3597        self.assertEqual(s.read(0), b"")
3598        self.assertEqual(s.read(), b"data")
3599
3600        # Should not block if the other end sends no data
3601        s.setblocking(False)
3602        self.assertEqual(s.recv(0), b"")
3603        self.assertEqual(s.recv_into(bytearray()), 0)
3604
3605    def test_nonblocking_send(self):
3606        server = ThreadedEchoServer(CERTFILE,
3607                                    certreqs=ssl.CERT_NONE,
3608                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3609                                    cacerts=CERTFILE,
3610                                    chatty=True,
3611                                    connectionchatty=False)
3612        with server:
3613            s = test_wrap_socket(socket.socket(),
3614                                server_side=False,
3615                                certfile=CERTFILE,
3616                                ca_certs=CERTFILE,
3617                                cert_reqs=ssl.CERT_NONE,
3618                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3619            s.connect((HOST, server.port))
3620            s.setblocking(False)
3621
3622            # If we keep sending data, at some point the buffers
3623            # will be full and the call will block
3624            buf = bytearray(8192)
3625            def fill_buffer():
3626                while True:
3627                    s.send(buf)
3628            self.assertRaises((ssl.SSLWantWriteError,
3629                               ssl.SSLWantReadError), fill_buffer)
3630
3631            # Now read all the output and discard it
3632            s.setblocking(True)
3633            s.close()
3634
3635    def test_handshake_timeout(self):
3636        # Issue #5103: SSL handshake must respect the socket timeout
3637        server = socket.socket(socket.AF_INET)
3638        host = "127.0.0.1"
3639        port = support.bind_port(server)
3640        started = threading.Event()
3641        finish = False
3642
3643        def serve():
3644            server.listen()
3645            started.set()
3646            conns = []
3647            while not finish:
3648                r, w, e = select.select([server], [], [], 0.1)
3649                if server in r:
3650                    # Let the socket hang around rather than having
3651                    # it closed by garbage collection.
3652                    conns.append(server.accept()[0])
3653            for sock in conns:
3654                sock.close()
3655
3656        t = threading.Thread(target=serve)
3657        t.start()
3658        started.wait()
3659
3660        try:
3661            try:
3662                c = socket.socket(socket.AF_INET)
3663                c.settimeout(0.2)
3664                c.connect((host, port))
3665                # Will attempt handshake and time out
3666                self.assertRaisesRegex(socket.timeout, "timed out",
3667                                       test_wrap_socket, c)
3668            finally:
3669                c.close()
3670            try:
3671                c = socket.socket(socket.AF_INET)
3672                c = test_wrap_socket(c)
3673                c.settimeout(0.2)
3674                # Will attempt handshake and time out
3675                self.assertRaisesRegex(socket.timeout, "timed out",
3676                                       c.connect, (host, port))
3677            finally:
3678                c.close()
3679        finally:
3680            finish = True
3681            t.join()
3682            server.close()
3683
3684    def test_server_accept(self):
3685        # Issue #16357: accept() on a SSLSocket created through
3686        # SSLContext.wrap_socket().
3687        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3688        context.verify_mode = ssl.CERT_REQUIRED
3689        context.load_verify_locations(SIGNING_CA)
3690        context.load_cert_chain(SIGNED_CERTFILE)
3691        server = socket.socket(socket.AF_INET)
3692        host = "127.0.0.1"
3693        port = support.bind_port(server)
3694        server = context.wrap_socket(server, server_side=True)
3695        self.assertTrue(server.server_side)
3696
3697        evt = threading.Event()
3698        remote = None
3699        peer = None
3700        def serve():
3701            nonlocal remote, peer
3702            server.listen()
3703            # Block on the accept and wait on the connection to close.
3704            evt.set()
3705            remote, peer = server.accept()
3706            remote.send(remote.recv(4))
3707
3708        t = threading.Thread(target=serve)
3709        t.start()
3710        # Client wait until server setup and perform a connect.
3711        evt.wait()
3712        client = context.wrap_socket(socket.socket())
3713        client.connect((host, port))
3714        client.send(b'data')
3715        client.recv()
3716        client_addr = client.getsockname()
3717        client.close()
3718        t.join()
3719        remote.close()
3720        server.close()
3721        # Sanity checks.
3722        self.assertIsInstance(remote, ssl.SSLSocket)
3723        self.assertEqual(peer, client_addr)
3724
3725    def test_getpeercert_enotconn(self):
3726        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3727        with context.wrap_socket(socket.socket()) as sock:
3728            with self.assertRaises(OSError) as cm:
3729                sock.getpeercert()
3730            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3731
3732    def test_do_handshake_enotconn(self):
3733        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3734        with context.wrap_socket(socket.socket()) as sock:
3735            with self.assertRaises(OSError) as cm:
3736                sock.do_handshake()
3737            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3738
3739    def test_no_shared_ciphers(self):
3740        client_context, server_context, hostname = testing_context()
3741        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3742        client_context.options |= ssl.OP_NO_TLSv1_3
3743        # Force different suites on client and server
3744        client_context.set_ciphers("AES128")
3745        server_context.set_ciphers("AES256")
3746        with ThreadedEchoServer(context=server_context) as server:
3747            with client_context.wrap_socket(socket.socket(),
3748                                            server_hostname=hostname) as s:
3749                with self.assertRaises(OSError):
3750                    s.connect((HOST, server.port))
3751        self.assertIn("no shared cipher", server.conn_errors[0])
3752
3753    def test_version_basic(self):
3754        """
3755        Basic tests for SSLSocket.version().
3756        More tests are done in the test_protocol_*() methods.
3757        """
3758        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3759        context.check_hostname = False
3760        context.verify_mode = ssl.CERT_NONE
3761        with ThreadedEchoServer(CERTFILE,
3762                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3763                                chatty=False) as server:
3764            with context.wrap_socket(socket.socket()) as s:
3765                self.assertIs(s.version(), None)
3766                self.assertIs(s._sslobj, None)
3767                s.connect((HOST, server.port))
3768                if IS_OPENSSL_1_1_1 and has_tls_version('TLSv1_3'):
3769                    self.assertEqual(s.version(), 'TLSv1.3')
3770                elif ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
3771                    self.assertEqual(s.version(), 'TLSv1.2')
3772                else:  # 0.9.8 to 1.0.1
3773                    self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
3774            self.assertIs(s._sslobj, None)
3775            self.assertIs(s.version(), None)
3776
3777    @requires_tls_version('TLSv1_3')
3778    def test_tls1_3(self):
3779        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3780        context.load_cert_chain(CERTFILE)
3781        context.options |= (
3782            ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
3783        )
3784        with ThreadedEchoServer(context=context) as server:
3785            with context.wrap_socket(socket.socket()) as s:
3786                s.connect((HOST, server.port))
3787                self.assertIn(s.cipher()[0], {
3788                    'TLS_AES_256_GCM_SHA384',
3789                    'TLS_CHACHA20_POLY1305_SHA256',
3790                    'TLS_AES_128_GCM_SHA256',
3791                })
3792                self.assertEqual(s.version(), 'TLSv1.3')
3793
3794    @requires_minimum_version
3795    @requires_tls_version('TLSv1_2')
3796    def test_min_max_version_tlsv1_2(self):
3797        client_context, server_context, hostname = testing_context()
3798        # client TLSv1.0 to 1.2
3799        client_context.minimum_version = ssl.TLSVersion.TLSv1
3800        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3801        # server only TLSv1.2
3802        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3803        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3804
3805        with ThreadedEchoServer(context=server_context) as server:
3806            with client_context.wrap_socket(socket.socket(),
3807                                            server_hostname=hostname) as s:
3808                s.connect((HOST, server.port))
3809                self.assertEqual(s.version(), 'TLSv1.2')
3810
3811    @requires_minimum_version
3812    @requires_tls_version('TLSv1_1')
3813    def test_min_max_version_tlsv1_1(self):
3814        client_context, server_context, hostname = testing_context()
3815        # client 1.0 to 1.2, server 1.0 to 1.1
3816        client_context.minimum_version = ssl.TLSVersion.TLSv1
3817        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3818        server_context.minimum_version = ssl.TLSVersion.TLSv1
3819        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3820
3821        with ThreadedEchoServer(context=server_context) as server:
3822            with client_context.wrap_socket(socket.socket(),
3823                                            server_hostname=hostname) as s:
3824                s.connect((HOST, server.port))
3825                self.assertEqual(s.version(), 'TLSv1.1')
3826
3827    @requires_minimum_version
3828    @requires_tls_version('TLSv1_2')
3829    def test_min_max_version_mismatch(self):
3830        client_context, server_context, hostname = testing_context()
3831        # client 1.0, server 1.2 (mismatch)
3832        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3833        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3834        client_context.maximum_version = ssl.TLSVersion.TLSv1
3835        client_context.minimum_version = ssl.TLSVersion.TLSv1
3836        with ThreadedEchoServer(context=server_context) as server:
3837            with client_context.wrap_socket(socket.socket(),
3838                                            server_hostname=hostname) as s:
3839                with self.assertRaises(ssl.SSLError) as e:
3840                    s.connect((HOST, server.port))
3841                self.assertIn("alert", str(e.exception))
3842
3843    @requires_minimum_version
3844    @requires_tls_version('SSLv3')
3845    def test_min_max_version_sslv3(self):
3846        client_context, server_context, hostname = testing_context()
3847        server_context.minimum_version = ssl.TLSVersion.SSLv3
3848        client_context.minimum_version = ssl.TLSVersion.SSLv3
3849        client_context.maximum_version = ssl.TLSVersion.SSLv3
3850        with ThreadedEchoServer(context=server_context) as server:
3851            with client_context.wrap_socket(socket.socket(),
3852                                            server_hostname=hostname) as s:
3853                s.connect((HOST, server.port))
3854                self.assertEqual(s.version(), 'SSLv3')
3855
3856    @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
3857    def test_default_ecdh_curve(self):
3858        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
3859        # should be enabled by default on SSL contexts.
3860        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3861        context.load_cert_chain(CERTFILE)
3862        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
3863        # cipher name.
3864        context.options |= ssl.OP_NO_TLSv1_3
3865        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
3866        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
3867        # our default cipher list should prefer ECDH-based ciphers
3868        # automatically.
3869        if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
3870            context.set_ciphers("ECCdraft:ECDH")
3871        with ThreadedEchoServer(context=context) as server:
3872            with context.wrap_socket(socket.socket()) as s:
3873                s.connect((HOST, server.port))
3874                self.assertIn("ECDH", s.cipher()[0])
3875
3876    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
3877                         "'tls-unique' channel binding not available")
3878    def test_tls_unique_channel_binding(self):
3879        """Test tls-unique channel binding."""
3880        if support.verbose:
3881            sys.stdout.write("\n")
3882
3883        client_context, server_context, hostname = testing_context()
3884
3885        server = ThreadedEchoServer(context=server_context,
3886                                    chatty=True,
3887                                    connectionchatty=False)
3888
3889        with server:
3890            with client_context.wrap_socket(
3891                    socket.socket(),
3892                    server_hostname=hostname) as s:
3893                s.connect((HOST, server.port))
3894                # get the data
3895                cb_data = s.get_channel_binding("tls-unique")
3896                if support.verbose:
3897                    sys.stdout.write(
3898                        " got channel binding data: {0!r}\n".format(cb_data))
3899
3900                # check if it is sane
3901                self.assertIsNotNone(cb_data)
3902                if s.version() == 'TLSv1.3':
3903                    self.assertEqual(len(cb_data), 48)
3904                else:
3905                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3906
3907                # and compare with the peers version
3908                s.write(b"CB tls-unique\n")
3909                peer_data_repr = s.read().strip()
3910                self.assertEqual(peer_data_repr,
3911                                 repr(cb_data).encode("us-ascii"))
3912
3913            # now, again
3914            with client_context.wrap_socket(
3915                    socket.socket(),
3916                    server_hostname=hostname) as s:
3917                s.connect((HOST, server.port))
3918                new_cb_data = s.get_channel_binding("tls-unique")
3919                if support.verbose:
3920                    sys.stdout.write(
3921                        "got another channel binding data: {0!r}\n".format(
3922                            new_cb_data)
3923                    )
3924                # is it really unique
3925                self.assertNotEqual(cb_data, new_cb_data)
3926                self.assertIsNotNone(cb_data)
3927                if s.version() == 'TLSv1.3':
3928                    self.assertEqual(len(cb_data), 48)
3929                else:
3930                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3931                s.write(b"CB tls-unique\n")
3932                peer_data_repr = s.read().strip()
3933                self.assertEqual(peer_data_repr,
3934                                 repr(new_cb_data).encode("us-ascii"))
3935
3936    def test_compression(self):
3937        client_context, server_context, hostname = testing_context()
3938        stats = server_params_test(client_context, server_context,
3939                                   chatty=True, connectionchatty=True,
3940                                   sni_name=hostname)
3941        if support.verbose:
3942            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
3943        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
3944
3945    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
3946                         "ssl.OP_NO_COMPRESSION needed for this test")
3947    def test_compression_disabled(self):
3948        client_context, server_context, hostname = testing_context()
3949        client_context.options |= ssl.OP_NO_COMPRESSION
3950        server_context.options |= ssl.OP_NO_COMPRESSION
3951        stats = server_params_test(client_context, server_context,
3952                                   chatty=True, connectionchatty=True,
3953                                   sni_name=hostname)
3954        self.assertIs(stats['compression'], None)
3955
3956    def test_dh_params(self):
3957        # Check we can get a connection with ephemeral Diffie-Hellman
3958        client_context, server_context, hostname = testing_context()
3959        # test scenario needs TLS <= 1.2
3960        client_context.options |= ssl.OP_NO_TLSv1_3
3961        server_context.load_dh_params(DHFILE)
3962        server_context.set_ciphers("kEDH")
3963        server_context.options |= ssl.OP_NO_TLSv1_3
3964        stats = server_params_test(client_context, server_context,
3965                                   chatty=True, connectionchatty=True,
3966                                   sni_name=hostname)
3967        cipher = stats["cipher"][0]
3968        parts = cipher.split("-")
3969        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
3970            self.fail("Non-DH cipher: " + cipher[0])
3971
3972    @unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
3973    @unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
3974    def test_ecdh_curve(self):
3975        # server secp384r1, client auto
3976        client_context, server_context, hostname = testing_context()
3977
3978        server_context.set_ecdh_curve("secp384r1")
3979        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3980        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3981        stats = server_params_test(client_context, server_context,
3982                                   chatty=True, connectionchatty=True,
3983                                   sni_name=hostname)
3984
3985        # server auto, client secp384r1
3986        client_context, server_context, hostname = testing_context()
3987        client_context.set_ecdh_curve("secp384r1")
3988        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3989        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
3990        stats = server_params_test(client_context, server_context,
3991                                   chatty=True, connectionchatty=True,
3992                                   sni_name=hostname)
3993
3994        # server / client curve mismatch
3995        client_context, server_context, hostname = testing_context()
3996        client_context.set_ecdh_curve("prime256v1")
3997        server_context.set_ecdh_curve("secp384r1")
3998        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
3999        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
4000        try:
4001            stats = server_params_test(client_context, server_context,
4002                                       chatty=True, connectionchatty=True,
4003                                       sni_name=hostname)
4004        except ssl.SSLError:
4005            pass
4006        else:
4007            # OpenSSL 1.0.2 does not fail although it should.
4008            if IS_OPENSSL_1_1_0:
4009                self.fail("mismatch curve did not fail")
4010
4011    def test_selected_alpn_protocol(self):
4012        # selected_alpn_protocol() is None unless ALPN is used.
4013        client_context, server_context, hostname = testing_context()
4014        stats = server_params_test(client_context, server_context,
4015                                   chatty=True, connectionchatty=True,
4016                                   sni_name=hostname)
4017        self.assertIs(stats['client_alpn_protocol'], None)
4018
4019    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
4020    def test_selected_alpn_protocol_if_server_uses_alpn(self):
4021        # selected_alpn_protocol() is None unless ALPN is used by the client.
4022        client_context, server_context, hostname = testing_context()
4023        server_context.set_alpn_protocols(['foo', 'bar'])
4024        stats = server_params_test(client_context, server_context,
4025                                   chatty=True, connectionchatty=True,
4026                                   sni_name=hostname)
4027        self.assertIs(stats['client_alpn_protocol'], None)
4028
4029    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
4030    def test_alpn_protocols(self):
4031        server_protocols = ['foo', 'bar', 'milkshake']
4032        protocol_tests = [
4033            (['foo', 'bar'], 'foo'),
4034            (['bar', 'foo'], 'foo'),
4035            (['milkshake'], 'milkshake'),
4036            (['http/3.0', 'http/4.0'], None)
4037        ]
4038        for client_protocols, expected in protocol_tests:
4039            client_context, server_context, hostname = testing_context()
4040            server_context.set_alpn_protocols(server_protocols)
4041            client_context.set_alpn_protocols(client_protocols)
4042
4043            try:
4044                stats = server_params_test(client_context,
4045                                           server_context,
4046                                           chatty=True,
4047                                           connectionchatty=True,
4048                                           sni_name=hostname)
4049            except ssl.SSLError as e:
4050                stats = e
4051
4052            if (expected is None and IS_OPENSSL_1_1_0
4053                    and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
4054                # OpenSSL 1.1.0 to 1.1.0e raises handshake error
4055                self.assertIsInstance(stats, ssl.SSLError)
4056            else:
4057                msg = "failed trying %s (s) and %s (c).\n" \
4058                    "was expecting %s, but got %%s from the %%s" \
4059                        % (str(server_protocols), str(client_protocols),
4060                            str(expected))
4061                client_result = stats['client_alpn_protocol']
4062                self.assertEqual(client_result, expected,
4063                                 msg % (client_result, "client"))
4064                server_result = stats['server_alpn_protocols'][-1] \
4065                    if len(stats['server_alpn_protocols']) else 'nothing'
4066                self.assertEqual(server_result, expected,
4067                                 msg % (server_result, "server"))
4068
4069    def test_selected_npn_protocol(self):
4070        # selected_npn_protocol() is None unless NPN is used
4071        client_context, server_context, hostname = testing_context()
4072        stats = server_params_test(client_context, server_context,
4073                                   chatty=True, connectionchatty=True,
4074                                   sni_name=hostname)
4075        self.assertIs(stats['client_npn_protocol'], None)
4076
4077    @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
4078    def test_npn_protocols(self):
4079        server_protocols = ['http/1.1', 'spdy/2']
4080        protocol_tests = [
4081            (['http/1.1', 'spdy/2'], 'http/1.1'),
4082            (['spdy/2', 'http/1.1'], 'http/1.1'),
4083            (['spdy/2', 'test'], 'spdy/2'),
4084            (['abc', 'def'], 'abc')
4085        ]
4086        for client_protocols, expected in protocol_tests:
4087            client_context, server_context, hostname = testing_context()
4088            server_context.set_npn_protocols(server_protocols)
4089            client_context.set_npn_protocols(client_protocols)
4090            stats = server_params_test(client_context, server_context,
4091                                       chatty=True, connectionchatty=True,
4092                                       sni_name=hostname)
4093            msg = "failed trying %s (s) and %s (c).\n" \
4094                  "was expecting %s, but got %%s from the %%s" \
4095                      % (str(server_protocols), str(client_protocols),
4096                         str(expected))
4097            client_result = stats['client_npn_protocol']
4098            self.assertEqual(client_result, expected, msg % (client_result, "client"))
4099            server_result = stats['server_npn_protocols'][-1] \
4100                if len(stats['server_npn_protocols']) else 'nothing'
4101            self.assertEqual(server_result, expected, msg % (server_result, "server"))
4102
4103    def sni_contexts(self):
4104        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4105        server_context.load_cert_chain(SIGNED_CERTFILE)
4106        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4107        other_context.load_cert_chain(SIGNED_CERTFILE2)
4108        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4109        client_context.load_verify_locations(SIGNING_CA)
4110        return server_context, other_context, client_context
4111
4112    def check_common_name(self, stats, name):
4113        cert = stats['peercert']
4114        self.assertIn((('commonName', name),), cert['subject'])
4115
4116    @needs_sni
4117    def test_sni_callback(self):
4118        calls = []
4119        server_context, other_context, client_context = self.sni_contexts()
4120
4121        client_context.check_hostname = False
4122
4123        def servername_cb(ssl_sock, server_name, initial_context):
4124            calls.append((server_name, initial_context))
4125            if server_name is not None:
4126                ssl_sock.context = other_context
4127        server_context.set_servername_callback(servername_cb)
4128
4129        stats = server_params_test(client_context, server_context,
4130                                   chatty=True,
4131                                   sni_name='supermessage')
4132        # The hostname was fetched properly, and the certificate was
4133        # changed for the connection.
4134        self.assertEqual(calls, [("supermessage", server_context)])
4135        # CERTFILE4 was selected
4136        self.check_common_name(stats, 'fakehostname')
4137
4138        calls = []
4139        # The callback is called with server_name=None
4140        stats = server_params_test(client_context, server_context,
4141                                   chatty=True,
4142                                   sni_name=None)
4143        self.assertEqual(calls, [(None, server_context)])
4144        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4145
4146        # Check disabling the callback
4147        calls = []
4148        server_context.set_servername_callback(None)
4149
4150        stats = server_params_test(client_context, server_context,
4151                                   chatty=True,
4152                                   sni_name='notfunny')
4153        # Certificate didn't change
4154        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4155        self.assertEqual(calls, [])
4156
4157    @needs_sni
4158    def test_sni_callback_alert(self):
4159        # Returning a TLS alert is reflected to the connecting client
4160        server_context, other_context, client_context = self.sni_contexts()
4161
4162        def cb_returning_alert(ssl_sock, server_name, initial_context):
4163            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4164        server_context.set_servername_callback(cb_returning_alert)
4165        with self.assertRaises(ssl.SSLError) as cm:
4166            stats = server_params_test(client_context, server_context,
4167                                       chatty=False,
4168                                       sni_name='supermessage')
4169        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4170
4171    @needs_sni
4172    def test_sni_callback_raising(self):
4173        # Raising fails the connection with a TLS handshake failure alert.
4174        server_context, other_context, client_context = self.sni_contexts()
4175
4176        def cb_raising(ssl_sock, server_name, initial_context):
4177            1/0
4178        server_context.set_servername_callback(cb_raising)
4179
4180        with support.catch_unraisable_exception() as catch:
4181            with self.assertRaises(ssl.SSLError) as cm:
4182                stats = server_params_test(client_context, server_context,
4183                                           chatty=False,
4184                                           sni_name='supermessage')
4185
4186            self.assertEqual(cm.exception.reason,
4187                             'SSLV3_ALERT_HANDSHAKE_FAILURE')
4188            self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
4189
4190    @needs_sni
4191    def test_sni_callback_wrong_return_type(self):
4192        # Returning the wrong return type terminates the TLS connection
4193        # with an internal error alert.
4194        server_context, other_context, client_context = self.sni_contexts()
4195
4196        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4197            return "foo"
4198        server_context.set_servername_callback(cb_wrong_return_type)
4199
4200        with support.catch_unraisable_exception() as catch:
4201            with self.assertRaises(ssl.SSLError) as cm:
4202                stats = server_params_test(client_context, server_context,
4203                                           chatty=False,
4204                                           sni_name='supermessage')
4205
4206
4207            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4208            self.assertEqual(catch.unraisable.exc_type, TypeError)
4209
4210    def test_shared_ciphers(self):
4211        client_context, server_context, hostname = testing_context()
4212        client_context.set_ciphers("AES128:AES256")
4213        server_context.set_ciphers("AES256")
4214        expected_algs = [
4215            "AES256", "AES-256",
4216            # TLS 1.3 ciphers are always enabled
4217            "TLS_CHACHA20", "TLS_AES",
4218        ]
4219
4220        stats = server_params_test(client_context, server_context,
4221                                   sni_name=hostname)
4222        ciphers = stats['server_shared_ciphers'][0]
4223        self.assertGreater(len(ciphers), 0)
4224        for name, tls_version, bits in ciphers:
4225            if not any(alg in name for alg in expected_algs):
4226                self.fail(name)
4227
4228    def test_read_write_after_close_raises_valuerror(self):
4229        client_context, server_context, hostname = testing_context()
4230        server = ThreadedEchoServer(context=server_context, chatty=False)
4231
4232        with server:
4233            s = client_context.wrap_socket(socket.socket(),
4234                                           server_hostname=hostname)
4235            s.connect((HOST, server.port))
4236            s.close()
4237
4238            self.assertRaises(ValueError, s.read, 1024)
4239            self.assertRaises(ValueError, s.write, b'hello')
4240
4241    def test_sendfile(self):
4242        TEST_DATA = b"x" * 512
4243        with open(support.TESTFN, 'wb') as f:
4244            f.write(TEST_DATA)
4245        self.addCleanup(support.unlink, support.TESTFN)
4246        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
4247        context.verify_mode = ssl.CERT_REQUIRED
4248        context.load_verify_locations(SIGNING_CA)
4249        context.load_cert_chain(SIGNED_CERTFILE)
4250        server = ThreadedEchoServer(context=context, chatty=False)
4251        with server:
4252            with context.wrap_socket(socket.socket()) as s:
4253                s.connect((HOST, server.port))
4254                with open(support.TESTFN, 'rb') as file:
4255                    s.sendfile(file)
4256                    self.assertEqual(s.recv(1024), TEST_DATA)
4257
4258    def test_session(self):
4259        client_context, server_context, hostname = testing_context()
4260        # TODO: sessions aren't compatible with TLSv1.3 yet
4261        client_context.options |= ssl.OP_NO_TLSv1_3
4262
4263        # first connection without session
4264        stats = server_params_test(client_context, server_context,
4265                                   sni_name=hostname)
4266        session = stats['session']
4267        self.assertTrue(session.id)
4268        self.assertGreater(session.time, 0)
4269        self.assertGreater(session.timeout, 0)
4270        self.assertTrue(session.has_ticket)
4271        if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
4272            self.assertGreater(session.ticket_lifetime_hint, 0)
4273        self.assertFalse(stats['session_reused'])
4274        sess_stat = server_context.session_stats()
4275        self.assertEqual(sess_stat['accept'], 1)
4276        self.assertEqual(sess_stat['hits'], 0)
4277
4278        # reuse session
4279        stats = server_params_test(client_context, server_context,
4280                                   session=session, sni_name=hostname)
4281        sess_stat = server_context.session_stats()
4282        self.assertEqual(sess_stat['accept'], 2)
4283        self.assertEqual(sess_stat['hits'], 1)
4284        self.assertTrue(stats['session_reused'])
4285        session2 = stats['session']
4286        self.assertEqual(session2.id, session.id)
4287        self.assertEqual(session2, session)
4288        self.assertIsNot(session2, session)
4289        self.assertGreaterEqual(session2.time, session.time)
4290        self.assertGreaterEqual(session2.timeout, session.timeout)
4291
4292        # another one without session
4293        stats = server_params_test(client_context, server_context,
4294                                   sni_name=hostname)
4295        self.assertFalse(stats['session_reused'])
4296        session3 = stats['session']
4297        self.assertNotEqual(session3.id, session.id)
4298        self.assertNotEqual(session3, session)
4299        sess_stat = server_context.session_stats()
4300        self.assertEqual(sess_stat['accept'], 3)
4301        self.assertEqual(sess_stat['hits'], 1)
4302
4303        # reuse session again
4304        stats = server_params_test(client_context, server_context,
4305                                   session=session, sni_name=hostname)
4306        self.assertTrue(stats['session_reused'])
4307        session4 = stats['session']
4308        self.assertEqual(session4.id, session.id)
4309        self.assertEqual(session4, session)
4310        self.assertGreaterEqual(session4.time, session.time)
4311        self.assertGreaterEqual(session4.timeout, session.timeout)
4312        sess_stat = server_context.session_stats()
4313        self.assertEqual(sess_stat['accept'], 4)
4314        self.assertEqual(sess_stat['hits'], 2)
4315
4316    def test_session_handling(self):
4317        client_context, server_context, hostname = testing_context()
4318        client_context2, _, _ = testing_context()
4319
4320        # TODO: session reuse does not work with TLSv1.3
4321        client_context.options |= ssl.OP_NO_TLSv1_3
4322        client_context2.options |= ssl.OP_NO_TLSv1_3
4323
4324        server = ThreadedEchoServer(context=server_context, chatty=False)
4325        with server:
4326            with client_context.wrap_socket(socket.socket(),
4327                                            server_hostname=hostname) as s:
4328                # session is None before handshake
4329                self.assertEqual(s.session, None)
4330                self.assertEqual(s.session_reused, None)
4331                s.connect((HOST, server.port))
4332                session = s.session
4333                self.assertTrue(session)
4334                with self.assertRaises(TypeError) as e:
4335                    s.session = object
4336                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4337
4338            with client_context.wrap_socket(socket.socket(),
4339                                            server_hostname=hostname) as s:
4340                s.connect((HOST, server.port))
4341                # cannot set session after handshake
4342                with self.assertRaises(ValueError) as e:
4343                    s.session = session
4344                self.assertEqual(str(e.exception),
4345                                 'Cannot set session after handshake.')
4346
4347            with client_context.wrap_socket(socket.socket(),
4348                                            server_hostname=hostname) as s:
4349                # can set session before handshake and before the
4350                # connection was established
4351                s.session = session
4352                s.connect((HOST, server.port))
4353                self.assertEqual(s.session.id, session.id)
4354                self.assertEqual(s.session, session)
4355                self.assertEqual(s.session_reused, True)
4356
4357            with client_context2.wrap_socket(socket.socket(),
4358                                             server_hostname=hostname) as s:
4359                # cannot re-use session with a different SSLContext
4360                with self.assertRaises(ValueError) as e:
4361                    s.session = session
4362                    s.connect((HOST, server.port))
4363                self.assertEqual(str(e.exception),
4364                                 'Session refers to a different SSLContext.')
4365
4366
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    """Tests for TLS 1.3 post-handshake client authentication (PHA)."""

    # The client drives the server with the byte commands b'HASCERT',
    # b'PHA' and b'GETCERT'. They are presumably interpreted by
    # ThreadedEchoServer (defined elsewhere in this file). The tests rely
    # on the replies asserted below, e.g. b'TRUE\n' / b'FALSE\n' for
    # HASCERT and b'OK\n' for a successful PHA request.

    def test_pha_setter(self):
        # post_handshake_auth defaults to False for every protocol. Toggling
        # it never changes verify_mode, and changing verify_mode never
        # changes it.
        protocols = [
            ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            ctx = ssl.SSLContext(protocol)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.post_handshake_auth = True
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.verify_mode = ssl.CERT_REQUIRED
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.post_handshake_auth = False
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.verify_mode = ssl.CERT_OPTIONAL
            ctx.post_handshake_auth = True
            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
            self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        # Both sides enable PHA and the client has a certificate. The server
        # sees no client certificate until the PHA request has been made.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        # The server requires a certificate via PHA but the client loaded
        # none, so the client receives a "certificate required" alert.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        # Ignore expected SSLError in ConnectionHandler of ThreadedEchoServer
        # (it is only raised sometimes on Windows)
        with support.catch_threading_exception() as cm:
            server = ThreadedEchoServer(context=server_context, chatty=False)
            with server:
                with client_context.wrap_socket(socket.socket(),
                                                server_hostname=hostname) as s:
                    s.connect((HOST, server.port))
                    s.write(b'PHA')
                    # receive CertificateRequest
                    self.assertEqual(s.recv(1024), b'OK\n')
                    # send empty Certificate + Finish
                    s.write(b'HASCERT')
                    # receive alert
                    with self.assertRaisesRegex(
                            ssl.SSLError,
                            'tlsv13 alert certificate required'):
                        s.recv(1024)

    def test_pha_optional(self):
        # With CERT_OPTIONAL on the server, a client that has a certificate
        # still provides it when PHA is requested.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        # NOTE(review): this CERT_REQUIRED setting is overwritten with
        # CERT_OPTIONAL a few lines below before any connection is made.
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        # check CERT_OPTIONAL
        server_context.verify_mode = ssl.CERT_OPTIONAL
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        # With CERT_OPTIONAL on the server, a client without a certificate
        # still completes the PHA request; the server just gets no cert.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        # The client does not enable post_handshake_auth.
        # verify_client_post_handshake() on the client side raises an error
        # matching 'not server', and the server's PHA attempt reports that
        # the extension was not received.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        # TLS 1.2
        # The client caps the protocol at TLS 1.2, where PHA is unavailable;
        # the server's reply to the PHA command contains WRONG_SSL_VERSION.
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))

    def test_bpo37428_pha_cert_none(self):
        # verify that post_handshake_auth does not implicitly enable cert
        # validation.
        hostname = SIGNED_CERTFILE_HOSTNAME
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)
        # no cert validation and CA on client side
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.load_verify_locations(SIGNING_CA)
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # server cert has not been validated
                self.assertEqual(s.getpeercert(), {})
4566                s.write(b'PHA')
4567                self.assertEqual(s.recv(1024), b'OK\n')
4568                s.write(b'HASCERT')
4569                self.assertEqual(s.recv(1024), b'TRUE\n')
4570                # server cert has not been validated
4571                self.assertEqual(s.getpeercert(), {})
4572
4573
4574HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
4575requires_keylog = unittest.skipUnless(
4576    HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')
4577
4578class TestSSLDebug(unittest.TestCase):
4579
4580    def keylog_lines(self, fname=support.TESTFN):
4581        with open(fname) as f:
4582            return len(list(f))
4583
4584    @requires_keylog
4585    def test_keylog_defaults(self):
4586        self.addCleanup(support.unlink, support.TESTFN)
4587        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4588        self.assertEqual(ctx.keylog_filename, None)
4589
4590        self.assertFalse(os.path.isfile(support.TESTFN))
4591        ctx.keylog_filename = support.TESTFN
4592        self.assertEqual(ctx.keylog_filename, support.TESTFN)
4593        self.assertTrue(os.path.isfile(support.TESTFN))
4594        self.assertEqual(self.keylog_lines(), 1)
4595
4596        ctx.keylog_filename = None
4597        self.assertEqual(ctx.keylog_filename, None)
4598
4599        with self.assertRaises((IsADirectoryError, PermissionError)):
4600            # Windows raises PermissionError
4601            ctx.keylog_filename = os.path.dirname(
4602                os.path.abspath(support.TESTFN))
4603
4604        with self.assertRaises(TypeError):
4605            ctx.keylog_filename = 1
4606
4607    @requires_keylog
4608    def test_keylog_filename(self):
4609        self.addCleanup(support.unlink, support.TESTFN)
4610        client_context, server_context, hostname = testing_context()
4611
4612        client_context.keylog_filename = support.TESTFN
4613        server = ThreadedEchoServer(context=server_context, chatty=False)
4614        with server:
4615            with client_context.wrap_socket(socket.socket(),
4616                                            server_hostname=hostname) as s:
4617                s.connect((HOST, server.port))
4618        # header, 5 lines for TLS 1.3
4619        self.assertEqual(self.keylog_lines(), 6)
4620
4621        client_context.keylog_filename = None
4622        server_context.keylog_filename = support.TESTFN
4623        server = ThreadedEchoServer(context=server_context, chatty=False)
4624        with server:
4625            with client_context.wrap_socket(socket.socket(),
4626                                            server_hostname=hostname) as s:
4627                s.connect((HOST, server.port))
4628        self.assertGreaterEqual(self.keylog_lines(), 11)
4629
4630        client_context.keylog_filename = support.TESTFN
4631        server_context.keylog_filename = support.TESTFN
4632        server = ThreadedEchoServer(context=server_context, chatty=False)
4633        with server:
4634            with client_context.wrap_socket(socket.socket(),
4635                                            server_hostname=hostname) as s:
4636                s.connect((HOST, server.port))
4637        self.assertGreaterEqual(self.keylog_lines(), 21)
4638
4639        client_context.keylog_filename = None
4640        server_context.keylog_filename = None
4641
4642    @requires_keylog
4643    @unittest.skipIf(sys.flags.ignore_environment,
4644                     "test is not compatible with ignore_environment")
4645    def test_keylog_env(self):
4646        self.addCleanup(support.unlink, support.TESTFN)
4647        with unittest.mock.patch.dict(os.environ):
4648            os.environ['SSLKEYLOGFILE'] = support.TESTFN
4649            self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN)
4650
4651            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4652            self.assertEqual(ctx.keylog_filename, None)
4653
4654            ctx = ssl.create_default_context()
4655            self.assertEqual(ctx.keylog_filename, support.TESTFN)
4656
4657            ctx = ssl._create_stdlib_context()
4658            self.assertEqual(ctx.keylog_filename, support.TESTFN)
4659
4660    def test_msg_callback(self):
4661        client_context, server_context, hostname = testing_context()
4662
4663        def msg_cb(conn, direction, version, content_type, msg_type, data):
4664            pass
4665
4666        self.assertIs(client_context._msg_callback, None)
4667        client_context._msg_callback = msg_cb
4668        self.assertIs(client_context._msg_callback, msg_cb)
4669        with self.assertRaises(TypeError):
4670            client_context._msg_callback = object()
4671
4672    def test_msg_callback_tls12(self):
4673        client_context, server_context, hostname = testing_context()
4674        client_context.options |= ssl.OP_NO_TLSv1_3
4675
4676        msg = []
4677
4678        def msg_cb(conn, direction, version, content_type, msg_type, data):
4679            self.assertIsInstance(conn, ssl.SSLSocket)
4680            self.assertIsInstance(data, bytes)
4681            self.assertIn(direction, {'read', 'write'})
4682            msg.append((direction, version, content_type, msg_type))
4683
4684        client_context._msg_callback = msg_cb
4685
4686        server = ThreadedEchoServer(context=server_context, chatty=False)
4687        with server:
4688            with client_context.wrap_socket(socket.socket(),
4689                                            server_hostname=hostname) as s:
4690                s.connect((HOST, server.port))
4691
4692        self.assertIn(
4693            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
4694             _TLSMessageType.SERVER_KEY_EXCHANGE),
4695            msg
4696        )
4697        self.assertIn(
4698            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
4699             _TLSMessageType.CHANGE_CIPHER_SPEC),
4700            msg
4701        )
4702
4703
4704def test_main(verbose=False):
4705    if support.verbose:
4706        plats = {
4707            'Mac': platform.mac_ver,
4708            'Windows': platform.win32_ver,
4709        }
4710        for name, func in plats.items():
4711            plat = func()
4712            if plat and plat[0]:
4713                plat = '%s %r' % (name, plat)
4714                break
4715        else:
4716            plat = repr(platform.platform())
4717        print("test_ssl: testing with %r %r" %
4718            (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
4719        print("          under %s" % plat)
4720        print("          HAS_SNI = %r" % ssl.HAS_SNI)
4721        print("          OP_ALL = 0x%8x" % ssl.OP_ALL)
4722        try:
4723            print("          OP_NO_TLSv1_1 = 0x%8x" % ssl.OP_NO_TLSv1_1)
4724        except AttributeError:
4725            pass
4726
4727    for filename in [
4728        CERTFILE, BYTES_CERTFILE,
4729        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
4730        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
4731        BADCERT, BADKEY, EMPTYCERT]:
4732        if not os.path.exists(filename):
4733            raise support.TestFailed("Can't read certificate file %r" % filename)
4734
4735    tests = [
4736        ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
4737        SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
4738        TestPostHandshakeAuth, TestSSLDebug
4739    ]
4740
4741    if support.is_resource_enabled('network'):
4742        tests.append(NetworkedTests)
4743
4744    thread_info = support.threading_setup()
4745    try:
4746        support.run_unittest(*tests)
4747    finally:
4748        support.threading_cleanup(*thread_info)
4749
4750if __name__ == "__main__":
4751    test_main()
4752