• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (C) Jean-Paul Calderone
2# See LICENSE for details.
3
4"""
5Unit tests for :mod:`OpenSSL.SSL`.
6"""
7
8import datetime
9import gc
10import sys
11import uuid
12
13from gc import collect, get_referrers
14from errno import (
15    EAFNOSUPPORT,
16    ECONNREFUSED,
17    EINPROGRESS,
18    EWOULDBLOCK,
19    EPIPE,
20    ESHUTDOWN,
21)
22from sys import platform, getfilesystemencoding
23from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket
24from os import makedirs
25from os.path import join
26from weakref import ref
27from warnings import simplefilter
28
29import flaky
30
31import pytest
32
33from pretend import raiser
34
35from six import PY2, text_type
36
37from cryptography import x509
38from cryptography.hazmat.backends import default_backend
39from cryptography.hazmat.primitives import hashes
40from cryptography.hazmat.primitives import serialization
41from cryptography.hazmat.primitives.asymmetric import rsa
42from cryptography.x509.oid import NameOID
43
44
45from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
46from OpenSSL.crypto import PKey, X509, X509Extension, X509Store
47from OpenSSL.crypto import dump_privatekey, load_privatekey
48from OpenSSL.crypto import dump_certificate, load_certificate
49from OpenSSL.crypto import get_elliptic_curves
50
51from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS
52from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON
53from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
54from OpenSSL.SSL import (
55    SSLv2_METHOD,
56    SSLv3_METHOD,
57    SSLv23_METHOD,
58    TLSv1_METHOD,
59    TLSv1_1_METHOD,
60    TLSv1_2_METHOD,
61)
62from OpenSSL.SSL import OP_SINGLE_DH_USE, OP_NO_SSLv2, OP_NO_SSLv3
63from OpenSSL.SSL import (
64    VERIFY_PEER,
65    VERIFY_FAIL_IF_NO_PEER_CERT,
66    VERIFY_CLIENT_ONCE,
67    VERIFY_NONE,
68)
69
70from OpenSSL import SSL
71from OpenSSL.SSL import (
72    SESS_CACHE_OFF,
73    SESS_CACHE_CLIENT,
74    SESS_CACHE_SERVER,
75    SESS_CACHE_BOTH,
76    SESS_CACHE_NO_AUTO_CLEAR,
77    SESS_CACHE_NO_INTERNAL_LOOKUP,
78    SESS_CACHE_NO_INTERNAL_STORE,
79    SESS_CACHE_NO_INTERNAL,
80)
81
82from OpenSSL.SSL import (
83    Error,
84    SysCallError,
85    WantReadError,
86    WantWriteError,
87    ZeroReturnError,
88)
89from OpenSSL.SSL import Context, Session, Connection, SSLeay_version
90from OpenSSL.SSL import _make_requires
91
92from OpenSSL._util import ffi as _ffi, lib as _lib
93
94from OpenSSL.SSL import (
95    OP_NO_QUERY_MTU,
96    OP_COOKIE_EXCHANGE,
97    OP_NO_TICKET,
98    OP_NO_COMPRESSION,
99    MODE_RELEASE_BUFFERS,
100    NO_OVERLAPPING_PROTOCOLS,
101)
102
103from OpenSSL.SSL import (
104    SSL_ST_CONNECT,
105    SSL_ST_ACCEPT,
106    SSL_ST_MASK,
107    SSL_CB_LOOP,
108    SSL_CB_EXIT,
109    SSL_CB_READ,
110    SSL_CB_WRITE,
111    SSL_CB_ALERT,
112    SSL_CB_READ_ALERT,
113    SSL_CB_WRITE_ALERT,
114    SSL_CB_ACCEPT_LOOP,
115    SSL_CB_ACCEPT_EXIT,
116    SSL_CB_CONNECT_LOOP,
117    SSL_CB_CONNECT_EXIT,
118    SSL_CB_HANDSHAKE_START,
119    SSL_CB_HANDSHAKE_DONE,
120)
121
122try:
123    from OpenSSL.SSL import (
124        SSL_ST_INIT,
125        SSL_ST_BEFORE,
126        SSL_ST_OK,
127        SSL_ST_RENEGOTIATE,
128    )
129except ImportError:
130    SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None
131
132from .util import WARNING_TYPE_EXPECTED, NON_ASCII, is_consistent_type
133from .test_crypto import (
134    client_cert_pem,
135    client_key_pem,
136    server_cert_pem,
137    server_key_pem,
138    root_cert_pem,
139    root_key_pem,
140)
141
142
# PEM-encoded Diffie-Hellman parameters.  Generated with:
# openssl dhparam 2048 -out dh-2048.pem
dhparam = """\
-----BEGIN DH PARAMETERS-----
MIIBCAKCAQEA2F5e976d/GjsaCdKv5RMWL/YV7fq1UUWpPAer5fDXflLMVUuYXxE
3m3ayZob9lbpgEU0jlPAsXHfQPGxpKmvhv+xV26V/DEoukED8JeZUY/z4pigoptl
+8+TYdNNE/rFSZQFXIp+v2D91IEgmHBnZlKFSbKR+p8i0KjExXGjU6ji3S5jkOku
ogikc7df1Ui0hWNJCmTjExq07aXghk97PsdFSxjdawuG3+vos5bnNoUwPLYlFc/z
ITYG0KXySiCLi4UDlXTZTz7u/+OYczPEgqa/JPUddbM/kfvaRAnjY38cfQ7qXf8Y
i5s5yYK7a/0eWxxRr2qraYaUj8RwDpH9CwIBAg==
-----END DH PARAMETERS-----
"""
154
155
# pytest marker: the decorated test is skipped unless ``PY2`` (imported from
# ``six``) is true, i.e. it only runs on Python 2.
skip_if_py3 = pytest.mark.skipif(not PY2, reason="Python 2 only")
157
158
def socket_any_family():
    """
    Return a new socket, preferring ``AF_INET`` and falling back to
    ``AF_INET6`` when creating the former fails with ``EAFNOSUPPORT``.

    Any other socket error is re-raised unchanged.
    """
    try:
        sock = socket(AF_INET)
    except error as exc:
        # Only an unsupported address family justifies the IPv6 fallback.
        if exc.errno != EAFNOSUPPORT:
            raise
        sock = socket(AF_INET6)
    return sock
166
167
def loopback_address(socket):
    """
    Return the loopback address literal matching the family of ``socket``:
    ``"127.0.0.1"`` for ``AF_INET``, ``"::1"`` for ``AF_INET6``.

    Any other family trips an assertion.
    """
    if socket.family == AF_INET:
        return "127.0.0.1"
    assert socket.family == AF_INET6
    return "::1"
174
175
def join_bytes_or_unicode(prefix, suffix):
    """
    Join two path components, each of which may be ``bytes`` or ``unicode``.

    When the types differ, ``suffix`` is converted to the type of ``prefix``
    using the filesystem encoding, so the result always has the type of
    ``prefix``.
    """
    if type(prefix) is not type(suffix):
        # Bring suffix over to prefix's type before joining.
        encoding = getfilesystemencoding()
        if isinstance(prefix, text_type):
            suffix = suffix.decode(encoding)
        else:
            suffix = suffix.encode(encoding)
    return join(prefix, suffix)
191
192
def verify_cb(conn, cert, errnum, depth, ok):
    """
    Verification callback that returns ``ok`` unchanged and ignores the
    other arguments.
    """
    return ok
195
196
def socket_pair():
    """
    Establish and return a pair of network sockets connected to each other.

    A temporary listening socket is bound to an ephemeral port, a client
    connects to it over the loopback address, and the accepted server-side
    socket is paired with the client.  The listening socket is closed before
    returning so it is not leaked.

    :return: A ``(server, client)`` tuple of connected, non-blocking sockets.
    """
    listener = socket_any_family()
    try:
        listener.bind(("", 0))
        listener.listen(1)
        client = socket(listener.family)
        # Start the connect without blocking; the accept() below completes
        # it.
        client.setblocking(False)
        client.connect_ex(
            (loopback_address(listener), listener.getsockname()[1])
        )
        client.setblocking(True)
        server = listener.accept()[0]
    finally:
        # The listener is only needed to establish the single connection.
        listener.close()

    # Let's pass some unencrypted data to make sure our socket connection is
    # fine.  Just one byte, so we don't have to worry about buffers getting
    # filled up or fragmentation.
    server.send(b"x")
    assert client.recv(1024) == b"x"
    client.send(b"y")
    assert server.recv(1024) == b"y"

    # Most of our callers want non-blocking sockets, make it easy for them.
    server.setblocking(False)
    client.setblocking(False)

    return (server, client)
224
225
def handshake(client, server):
    """
    Drive ``do_handshake`` on both connections until each has completed.

    A connection that raises ``WantReadError`` is retried on the next round;
    one whose ``do_handshake`` returns normally is considered finished.

    :param client: The client-side connection.
    :param server: The server-side connection.
    """
    pending = [client, server]
    while pending:
        # Build the next round's list instead of removing entries from the
        # list being iterated, which would skip the following element.
        unfinished = []
        for conn in pending:
            try:
                conn.do_handshake()
            except WantReadError:
                unfinished.append(conn)
        pending = unfinished
236
237
def _create_certificate_chain():
    """
    Build a three-level certificate chain and return it as a list of
    ``(key, certificate)`` pairs, ordered from the root downwards:

        1. A self-signed certificate authority certificate
        2. An intermediate certificate signed by the authority key
        3. A server certificate signed by the intermediate key
    """
    ca_extension = X509Extension(b"basicConstraints", False, b"CA:true")
    expiry_date = datetime.date.today() + datetime.timedelta(days=365)
    not_after = expiry_date.strftime("%Y%m%d%H%M%SZ").encode("ascii")

    def make_pair(common_name, extension, issuer=None):
        # Generate a fresh 2048-bit RSA key and a certificate for it.  With
        # ``issuer`` omitted the certificate is self-issued and self-signed;
        # otherwise ``issuer`` is the (key, cert) pair that issues and signs.
        key = PKey()
        key.generate_key(TYPE_RSA, 2048)
        cert = X509()
        cert.set_version(2)
        cert.get_subject().commonName = common_name
        if issuer is None:
            signing_key, issuer_cert = key, cert
        else:
            signing_key, issuer_cert = issuer
        cert.set_issuer(issuer_cert.get_subject())
        cert.set_pubkey(key)
        cert.set_notBefore(b"20000101000000Z")
        cert.set_notAfter(not_after)
        cert.add_extensions([extension])
        cert.set_serial_number(0)
        cert.sign(signing_key, "sha256")
        return (key, cert)

    authority = make_pair("Authority Certificate", ca_extension)
    intermediate = make_pair(
        "Intermediate Certificate", ca_extension, issuer=authority
    )
    server = make_pair(
        "Server Certificate",
        X509Extension(b"basicConstraints", True, b"CA:false"),
        issuer=intermediate,
    )

    return [authority, intermediate, server]
295
296
def loopback_client_factory(socket, version=SSLv23_METHOD):
    """
    Wrap ``socket`` in a `Connection` using a fresh `Context` for
    ``version`` and put it into connect (client) state.
    """
    ctx = Context(version)
    conn = Connection(ctx, socket)
    conn.set_connect_state()
    return conn
301
302
def loopback_server_factory(socket, version=SSLv23_METHOD):
    """
    Wrap ``socket`` in a `Connection` put into accept (server) state.

    The underlying `Context` for ``version`` is configured with the key and
    certificate loaded from ``server_key_pem`` and ``server_cert_pem``.
    """
    context = Context(version)
    private_key = load_privatekey(FILETYPE_PEM, server_key_pem)
    context.use_privatekey(private_key)
    certificate = load_certificate(FILETYPE_PEM, server_cert_pem)
    context.use_certificate(certificate)
    conn = Connection(context, socket)
    conn.set_accept_state()
    return conn
310
311
def loopback(server_factory=None, client_factory=None):
    """
    Create a connected socket pair, wrap each end using the given factories
    and complete the handshake between the two resulting connections.

    ``server_factory`` and ``client_factory`` default to
    `loopback_server_factory` and `loopback_client_factory`.  Both
    connections are switched to blocking mode before being returned as a
    ``(server, client)`` tuple.
    """
    make_server = (
        loopback_server_factory if server_factory is None else server_factory
    )
    make_client = (
        loopback_client_factory if client_factory is None else client_factory
    )

    server_sock, client_sock = socket_pair()
    server = make_server(server_sock)
    client = make_client(client_sock)

    handshake(client, server)

    for conn in (server, client):
        conn.setblocking(True)
    return server, client
331
332
def interact_in_memory(client_conn, server_conn):
    """
    Shuttle bytes between two `Connection` objects until one of them yields
    application data or neither has anything left to send.

    Each pass first tries ``recv`` on a side; if that succeeds, a
    ``(connection, bytes)`` tuple is returned straight away.  Otherwise
    everything that side's ``bio_read`` hands out is fed to the other side's
    ``bio_write``.  When a full pass moves no bytes in either direction,
    `None` is returned.
    """
    directions = [
        (client_conn, server_conn),
        (server_conn, client_conn),
    ]

    progress = True
    while progress:
        # Stays False unless some bytes are moved during this pass.
        progress = False

        for source, sink in directions:
            # Let this side make progress; it may produce application data.
            try:
                payload = source.recv(2 ** 16)
            except WantReadError:
                # Nothing to deliver yet; it may still have produced output.
                pass
            else:
                # Application bytes arrived: hand them to the caller.
                return (source, payload)

            # Drain this side's outgoing buffer into the peer.
            while True:
                try:
                    outgoing = source.bio_read(4096)
                except WantReadError:
                    # Outgoing buffer is empty; go on to the next direction.
                    break
                progress = True
                sink.bio_write(outgoing)
378
379
def handshake_in_memory(client_conn, server_conn):
    """
    Run the TLS handshake between two `Connection` objects.

    The client is put into connect state and the server into accept state,
    each gets one initial ``do_handshake`` call (``WantReadError`` is
    ignored), and `interact_in_memory` then moves the bytes between them.
    """
    client_conn.set_connect_state()
    server_conn.set_accept_state()

    for endpoint in (client_conn, server_conn):
        try:
            endpoint.do_handshake()
        except WantReadError:
            continue

    interact_in_memory(client_conn, server_conn)
395
396
class TestVersion(object):
    """
    Tests covering the version details published through
    `OpenSSL.SSL.SSLeay_version` and `OpenSSL.SSL.OPENSSL_VERSION_NUMBER`.
    """

    def test_OPENSSL_VERSION_NUMBER(self):
        """
        `OPENSSL_VERSION_NUMBER` is exposed as an ``int``.
        """
        assert isinstance(OPENSSL_VERSION_NUMBER, int)

    def test_SSLeay_version(self):
        """
        `SSLeay_version` returns a ``bytes`` string for each of the five
        version type indicators, and the five strings are all distinct.
        """
        indicators = (
            SSLEAY_VERSION,
            SSLEAY_CFLAGS,
            SSLEAY_BUILT_ON,
            SSLEAY_PLATFORM,
            SSLEAY_DIR,
        )
        # Keyed by the returned string so duplicates collapse into one entry.
        seen = {}
        for indicator in indicators:
            text = SSLeay_version(indicator)
            seen[text] = indicator
            assert isinstance(text, bytes)
        assert len(seen) == 5
427
428
@pytest.fixture
def ca_file(tmpdir):
    """
    Write a freshly generated, self-signed CA certificate in PEM format to
    ``test.pem`` inside ``tmpdir`` and return the path as ASCII ``bytes``.
    """
    private_key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    def common_name():
        # Subject and issuer carry the same common name (self-signed).
        return x509.Name(
            [x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")]
        )

    one_day = datetime.timedelta(days=1)
    builder = (
        x509.CertificateBuilder()
        .subject_name(common_name())
        .issuer_name(common_name())
        .not_valid_before(datetime.datetime.today() - one_day)
        .not_valid_after(datetime.datetime.today() + one_day)
        .serial_number(int(uuid.uuid4()))
        .public_key(private_key.public_key())
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
    )

    certificate = builder.sign(
        private_key=private_key,
        algorithm=hashes.SHA256(),
        backend=default_backend(),
    )

    pem_path = tmpdir.join("test.pem")
    pem_bytes = certificate.public_bytes(encoding=serialization.Encoding.PEM)
    pem_path.write_binary(pem_bytes)

    return str(pem_path).encode("ascii")
468
469
@pytest.fixture
def context():
    """
    Provide a fresh `Context` created with `SSLv23_METHOD`.
    """
    ctx = Context(SSLv23_METHOD)
    return ctx
476
477
478class TestContext(object):
479    """
480    Unit tests for `OpenSSL.SSL.Context`.
481    """
482
483    @pytest.mark.parametrize(
484        "cipher_string",
485        [b"hello world:AES128-SHA", u"hello world:AES128-SHA"],
486    )
487    def test_set_cipher_list(self, context, cipher_string):
488        """
489        `Context.set_cipher_list` accepts both byte and unicode strings
490        for naming the ciphers which connections created with the context
491        object will be able to choose from.
492        """
493        context.set_cipher_list(cipher_string)
494        conn = Connection(context, None)
495
496        assert "AES128-SHA" in conn.get_cipher_list()
497
498    def test_set_cipher_list_wrong_type(self, context):
499        """
500        `Context.set_cipher_list` raises `TypeError` when passed a non-string
501        argument.
502        """
503        with pytest.raises(TypeError):
504            context.set_cipher_list(object())
505
506    @flaky.flaky
507    def test_set_cipher_list_no_cipher_match(self, context):
508        """
509        `Context.set_cipher_list` raises `OpenSSL.SSL.Error` with a
510        `"no cipher match"` reason string regardless of the TLS
511        version.
512        """
513        with pytest.raises(Error) as excinfo:
514            context.set_cipher_list(b"imaginary-cipher")
515        assert excinfo.value.args == (
516            [
517                (
518                    "SSL routines",
519                    "SSL_CTX_set_cipher_list",
520                    "no cipher match",
521                )
522            ],
523        )
524
525    def test_load_client_ca(self, context, ca_file):
526        """
527        `Context.load_client_ca` works as far as we can tell.
528        """
529        context.load_client_ca(ca_file)
530
531    def test_load_client_ca_invalid(self, context, tmpdir):
532        """
533        `Context.load_client_ca` raises an Error if the ca file is invalid.
534        """
535        ca_file = tmpdir.join("test.pem")
536        ca_file.write("")
537
538        with pytest.raises(Error) as e:
539            context.load_client_ca(str(ca_file).encode("ascii"))
540
541        assert "PEM routines" == e.value.args[0][0][0]
542
543    def test_load_client_ca_unicode(self, context, ca_file):
544        """
545        Passing the path as unicode raises a warning but works.
546        """
547        pytest.deprecated_call(context.load_client_ca, ca_file.decode("ascii"))
548
549    def test_set_session_id(self, context):
550        """
551        `Context.set_session_id` works as far as we can tell.
552        """
553        context.set_session_id(b"abc")
554
555    def test_set_session_id_fail(self, context):
556        """
557        `Context.set_session_id` errors are propagated.
558        """
559        with pytest.raises(Error) as e:
560            context.set_session_id(b"abc" * 1000)
561
562        assert [
563            (
564                "SSL routines",
565                "SSL_CTX_set_session_id_context",
566                "ssl session id context too long",
567            )
568        ] == e.value.args[0]
569
570    def test_set_session_id_unicode(self, context):
571        """
572        `Context.set_session_id` raises a warning if a unicode string is
573        passed.
574        """
575        pytest.deprecated_call(context.set_session_id, u"abc")
576
577    def test_method(self):
578        """
579        `Context` can be instantiated with one of `SSLv2_METHOD`,
580        `SSLv3_METHOD`, `SSLv23_METHOD`, `TLSv1_METHOD`, `TLSv1_1_METHOD`,
581        or `TLSv1_2_METHOD`.
582        """
583        methods = [SSLv23_METHOD, TLSv1_METHOD]
584        for meth in methods:
585            Context(meth)
586
587        maybe = [SSLv2_METHOD, SSLv3_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD]
588        for meth in maybe:
589            try:
590                Context(meth)
591            except (Error, ValueError):
592                # Some versions of OpenSSL have SSLv2 / TLSv1.1 / TLSv1.2, some
593                # don't.  Difficult to say in advance.
594                pass
595
596        with pytest.raises(TypeError):
597            Context("")
598        with pytest.raises(ValueError):
599            Context(10)
600
601    def test_type(self):
602        """
603        `Context` can be used to create instances of that type.
604        """
605        assert is_consistent_type(Context, "Context", TLSv1_METHOD)
606
607    def test_use_privatekey(self):
608        """
609        `Context.use_privatekey` takes an `OpenSSL.crypto.PKey` instance.
610        """
611        key = PKey()
612        key.generate_key(TYPE_RSA, 1024)
613        ctx = Context(SSLv23_METHOD)
614        ctx.use_privatekey(key)
615        with pytest.raises(TypeError):
616            ctx.use_privatekey("")
617
618    def test_use_privatekey_file_missing(self, tmpfile):
619        """
620        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` when passed
621        the name of a file which does not exist.
622        """
623        ctx = Context(SSLv23_METHOD)
624        with pytest.raises(Error):
625            ctx.use_privatekey_file(tmpfile)
626
627    def _use_privatekey_file_test(self, pemfile, filetype):
628        """
629        Verify that calling ``Context.use_privatekey_file`` with the given
630        arguments does not raise an exception.
631        """
632        key = PKey()
633        key.generate_key(TYPE_RSA, 1024)
634
635        with open(pemfile, "wt") as pem:
636            pem.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii"))
637
638        ctx = Context(SSLv23_METHOD)
639        ctx.use_privatekey_file(pemfile, filetype)
640
641    @pytest.mark.parametrize("filetype", [object(), "", None, 1.0])
642    def test_wrong_privatekey_file_wrong_args(self, tmpfile, filetype):
643        """
644        `Context.use_privatekey_file` raises `TypeError` when called with
645        a `filetype` which is not a valid file encoding.
646        """
647        ctx = Context(SSLv23_METHOD)
648        with pytest.raises(TypeError):
649            ctx.use_privatekey_file(tmpfile, filetype)
650
651    def test_use_privatekey_file_bytes(self, tmpfile):
652        """
653        A private key can be specified from a file by passing a ``bytes``
654        instance giving the file name to ``Context.use_privatekey_file``.
655        """
656        self._use_privatekey_file_test(
657            tmpfile + NON_ASCII.encode(getfilesystemencoding()),
658            FILETYPE_PEM,
659        )
660
661    def test_use_privatekey_file_unicode(self, tmpfile):
662        """
663        A private key can be specified from a file by passing a ``unicode``
664        instance giving the file name to ``Context.use_privatekey_file``.
665        """
666        self._use_privatekey_file_test(
667            tmpfile.decode(getfilesystemencoding()) + NON_ASCII,
668            FILETYPE_PEM,
669        )
670
671    def test_use_certificate_wrong_args(self):
672        """
673        `Context.use_certificate_wrong_args` raises `TypeError` when not passed
674        exactly one `OpenSSL.crypto.X509` instance as an argument.
675        """
676        ctx = Context(SSLv23_METHOD)
677        with pytest.raises(TypeError):
678            ctx.use_certificate("hello, world")
679
680    def test_use_certificate_uninitialized(self):
681        """
682        `Context.use_certificate` raises `OpenSSL.SSL.Error` when passed a
683        `OpenSSL.crypto.X509` instance which has not been initialized
684        (ie, which does not actually have any certificate data).
685        """
686        ctx = Context(SSLv23_METHOD)
687        with pytest.raises(Error):
688            ctx.use_certificate(X509())
689
690    def test_use_certificate(self):
691        """
692        `Context.use_certificate` sets the certificate which will be
693        used to identify connections created using the context.
694        """
695        # TODO
696        # Hard to assert anything.  But we could set a privatekey then ask
697        # OpenSSL if the cert and key agree using check_privatekey.  Then as
698        # long as check_privatekey works right we're good...
699        ctx = Context(SSLv23_METHOD)
700        ctx.use_certificate(load_certificate(FILETYPE_PEM, root_cert_pem))
701
702    def test_use_certificate_file_wrong_args(self):
703        """
704        `Context.use_certificate_file` raises `TypeError` if the first
705        argument is not a byte string or the second argument is not an integer.
706        """
707        ctx = Context(SSLv23_METHOD)
708        with pytest.raises(TypeError):
709            ctx.use_certificate_file(object(), FILETYPE_PEM)
710        with pytest.raises(TypeError):
711            ctx.use_certificate_file(b"somefile", object())
712        with pytest.raises(TypeError):
713            ctx.use_certificate_file(object(), FILETYPE_PEM)
714
715    def test_use_certificate_file_missing(self, tmpfile):
716        """
717        `Context.use_certificate_file` raises `OpenSSL.SSL.Error` if passed
718        the name of a file which does not exist.
719        """
720        ctx = Context(SSLv23_METHOD)
721        with pytest.raises(Error):
722            ctx.use_certificate_file(tmpfile)
723
724    def _use_certificate_file_test(self, certificate_file):
725        """
726        Verify that calling ``Context.use_certificate_file`` with the given
727        filename doesn't raise an exception.
728        """
729        # TODO
730        # Hard to assert anything.  But we could set a privatekey then ask
731        # OpenSSL if the cert and key agree using check_privatekey.  Then as
732        # long as check_privatekey works right we're good...
733        with open(certificate_file, "wb") as pem_file:
734            pem_file.write(root_cert_pem)
735
736        ctx = Context(SSLv23_METHOD)
737        ctx.use_certificate_file(certificate_file)
738
739    def test_use_certificate_file_bytes(self, tmpfile):
740        """
741        `Context.use_certificate_file` sets the certificate (given as a
742        `bytes` filename) which will be used to identify connections created
743        using the context.
744        """
745        filename = tmpfile + NON_ASCII.encode(getfilesystemencoding())
746        self._use_certificate_file_test(filename)
747
748    def test_use_certificate_file_unicode(self, tmpfile):
749        """
750        `Context.use_certificate_file` sets the certificate (given as a
751        `bytes` filename) which will be used to identify connections created
752        using the context.
753        """
754        filename = tmpfile.decode(getfilesystemencoding()) + NON_ASCII
755        self._use_certificate_file_test(filename)
756
757    def test_check_privatekey_valid(self):
758        """
759        `Context.check_privatekey` returns `None` if the `Context` instance
760        has been configured to use a matched key and certificate pair.
761        """
762        key = load_privatekey(FILETYPE_PEM, client_key_pem)
763        cert = load_certificate(FILETYPE_PEM, client_cert_pem)
764        context = Context(SSLv23_METHOD)
765        context.use_privatekey(key)
766        context.use_certificate(cert)
767        assert None is context.check_privatekey()
768
769    def test_check_privatekey_invalid(self):
770        """
771        `Context.check_privatekey` raises `Error` if the `Context` instance
772        has been configured to use a key and certificate pair which don't
773        relate to each other.
774        """
775        key = load_privatekey(FILETYPE_PEM, client_key_pem)
776        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
777        context = Context(SSLv23_METHOD)
778        context.use_privatekey(key)
779        context.use_certificate(cert)
780        with pytest.raises(Error):
781            context.check_privatekey()
782
783    def test_app_data(self):
784        """
785        `Context.set_app_data` stores an object for later retrieval
786        using `Context.get_app_data`.
787        """
788        app_data = object()
789        context = Context(SSLv23_METHOD)
790        context.set_app_data(app_data)
791        assert context.get_app_data() is app_data
792
793    def test_set_options_wrong_args(self):
794        """
795        `Context.set_options` raises `TypeError` if called with
796        a non-`int` argument.
797        """
798        context = Context(SSLv23_METHOD)
799        with pytest.raises(TypeError):
800            context.set_options(None)
801
802    def test_set_options(self):
803        """
804        `Context.set_options` returns the new options value.
805        """
806        context = Context(SSLv23_METHOD)
807        options = context.set_options(OP_NO_SSLv2)
808        assert options & OP_NO_SSLv2 == OP_NO_SSLv2
809
810    def test_set_mode_wrong_args(self):
811        """
812        `Context.set_mode` raises `TypeError` if called with
813        a non-`int` argument.
814        """
815        context = Context(SSLv23_METHOD)
816        with pytest.raises(TypeError):
817            context.set_mode(None)
818
819    def test_set_mode(self):
820        """
821        `Context.set_mode` accepts a mode bitvector and returns the
822        newly set mode.
823        """
824        context = Context(SSLv23_METHOD)
825        assert MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS)
826
827    def test_set_timeout_wrong_args(self):
828        """
829        `Context.set_timeout` raises `TypeError` if called with
830        a non-`int` argument.
831        """
832        context = Context(SSLv23_METHOD)
833        with pytest.raises(TypeError):
834            context.set_timeout(None)
835
836    def test_timeout(self):
837        """
838        `Context.set_timeout` sets the session timeout for all connections
839        created using the context object. `Context.get_timeout` retrieves
840        this value.
841        """
842        context = Context(SSLv23_METHOD)
843        context.set_timeout(1234)
844        assert context.get_timeout() == 1234
845
846    def test_set_verify_depth_wrong_args(self):
847        """
848        `Context.set_verify_depth` raises `TypeError` if called with a
849        non-`int` argument.
850        """
851        context = Context(SSLv23_METHOD)
852        with pytest.raises(TypeError):
853            context.set_verify_depth(None)
854
855    def test_verify_depth(self):
856        """
857        `Context.set_verify_depth` sets the number of certificates in
858        a chain to follow before giving up.  The value can be retrieved with
859        `Context.get_verify_depth`.
860        """
861        context = Context(SSLv23_METHOD)
862        context.set_verify_depth(11)
863        assert context.get_verify_depth() == 11
864
865    def _write_encrypted_pem(self, passphrase, tmpfile):
866        """
867        Write a new private key out to a new file, encrypted using the given
868        passphrase.  Return the path to the new file.
869        """
870        key = PKey()
871        key.generate_key(TYPE_RSA, 1024)
872        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase)
873        with open(tmpfile, "w") as fObj:
874            fObj.write(pem.decode("ascii"))
875        return tmpfile
876
877    def test_set_passwd_cb_wrong_args(self):
878        """
879        `Context.set_passwd_cb` raises `TypeError` if called with a
880        non-callable first argument.
881        """
882        context = Context(SSLv23_METHOD)
883        with pytest.raises(TypeError):
884            context.set_passwd_cb(None)
885
886    def test_set_passwd_cb(self, tmpfile):
887        """
888        `Context.set_passwd_cb` accepts a callable which will be invoked when
889        a private key is loaded from an encrypted PEM.
890        """
891        passphrase = b"foobar"
892        pemFile = self._write_encrypted_pem(passphrase, tmpfile)
893        calledWith = []
894
895        def passphraseCallback(maxlen, verify, extra):
896            calledWith.append((maxlen, verify, extra))
897            return passphrase
898
899        context = Context(SSLv23_METHOD)
900        context.set_passwd_cb(passphraseCallback)
901        context.use_privatekey_file(pemFile)
902        assert len(calledWith) == 1
903        assert isinstance(calledWith[0][0], int)
904        assert isinstance(calledWith[0][1], int)
905        assert calledWith[0][2] is None
906
907    def test_passwd_callback_exception(self, tmpfile):
908        """
909        `Context.use_privatekey_file` propagates any exception raised
910        by the passphrase callback.
911        """
912        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
913
914        def passphraseCallback(maxlen, verify, extra):
915            raise RuntimeError("Sorry, I am a fail.")
916
917        context = Context(SSLv23_METHOD)
918        context.set_passwd_cb(passphraseCallback)
919        with pytest.raises(RuntimeError):
920            context.use_privatekey_file(pemFile)
921
922    def test_passwd_callback_false(self, tmpfile):
923        """
924        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the
925        passphrase callback returns a false value.
926        """
927        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
928
929        def passphraseCallback(maxlen, verify, extra):
930            return b""
931
932        context = Context(SSLv23_METHOD)
933        context.set_passwd_cb(passphraseCallback)
934        with pytest.raises(Error):
935            context.use_privatekey_file(pemFile)
936
937    def test_passwd_callback_non_string(self, tmpfile):
938        """
939        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the
940        passphrase callback returns a true non-string value.
941        """
942        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
943
944        def passphraseCallback(maxlen, verify, extra):
945            return 10
946
947        context = Context(SSLv23_METHOD)
948        context.set_passwd_cb(passphraseCallback)
949        # TODO: Surely this is the wrong error?
950        with pytest.raises(ValueError):
951            context.use_privatekey_file(pemFile)
952
953    def test_passwd_callback_too_long(self, tmpfile):
954        """
955        If the passphrase returned by the passphrase callback returns a string
956        longer than the indicated maximum length, it is truncated.
957        """
958        # A priori knowledge!
959        passphrase = b"x" * 1024
960        pemFile = self._write_encrypted_pem(passphrase, tmpfile)
961
962        def passphraseCallback(maxlen, verify, extra):
963            assert maxlen == 1024
964            return passphrase + b"y"
965
966        context = Context(SSLv23_METHOD)
967        context.set_passwd_cb(passphraseCallback)
968        # This shall succeed because the truncated result is the correct
969        # passphrase.
970        context.use_privatekey_file(pemFile)
971
972    def test_set_info_callback(self):
973        """
974        `Context.set_info_callback` accepts a callable which will be
975        invoked when certain information about an SSL connection is available.
976        """
977        (server, client) = socket_pair()
978
979        clientSSL = Connection(Context(SSLv23_METHOD), client)
980        clientSSL.set_connect_state()
981
982        called = []
983
984        def info(conn, where, ret):
985            called.append((conn, where, ret))
986
987        context = Context(SSLv23_METHOD)
988        context.set_info_callback(info)
989        context.use_certificate(load_certificate(FILETYPE_PEM, root_cert_pem))
990        context.use_privatekey(load_privatekey(FILETYPE_PEM, root_key_pem))
991
992        serverSSL = Connection(context, server)
993        serverSSL.set_accept_state()
994
995        handshake(clientSSL, serverSSL)
996
997        # The callback must always be called with a Connection instance as the
998        # first argument.  It would probably be better to split this into
999        # separate tests for client and server side info callbacks so we could
1000        # assert it is called with the right Connection instance.  It would
1001        # also be good to assert *something* about `where` and `ret`.
1002        notConnections = [
1003            conn
1004            for (conn, where, ret) in called
1005            if not isinstance(conn, Connection)
1006        ]
1007        assert (
1008            [] == notConnections
1009        ), "Some info callback arguments were not Connection instances."
1010
1011    @pytest.mark.skipif(
1012        not getattr(_lib, "Cryptography_HAS_KEYLOG", None),
1013        reason="SSL_CTX_set_keylog_callback unavailable",
1014    )
1015    def test_set_keylog_callback(self):
1016        """
1017        `Context.set_keylog_callback` accepts a callable which will be
1018        invoked when key material is generated or received.
1019        """
1020        called = []
1021
1022        def keylog(conn, line):
1023            called.append((conn, line))
1024
1025        server_context = Context(TLSv1_2_METHOD)
1026        server_context.set_keylog_callback(keylog)
1027        server_context.use_certificate(
1028            load_certificate(FILETYPE_PEM, root_cert_pem)
1029        )
1030        server_context.use_privatekey(
1031            load_privatekey(FILETYPE_PEM, root_key_pem)
1032        )
1033
1034        client_context = Context(SSLv23_METHOD)
1035
1036        self._handshake_test(server_context, client_context)
1037
1038        assert called
1039        assert all(isinstance(conn, Connection) for conn, line in called)
1040        assert all(b"CLIENT_RANDOM" in line for conn, line in called)
1041
1042    def _load_verify_locations_test(self, *args):
1043        """
1044        Create a client context which will verify the peer certificate and call
1045        its `load_verify_locations` method with the given arguments.
1046        Then connect it to a server and ensure that the handshake succeeds.
1047        """
1048        (server, client) = socket_pair()
1049
1050        clientContext = Context(SSLv23_METHOD)
1051        clientContext.load_verify_locations(*args)
1052        # Require that the server certificate verify properly or the
1053        # connection will fail.
1054        clientContext.set_verify(
1055            VERIFY_PEER,
1056            lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
1057        )
1058
1059        clientSSL = Connection(clientContext, client)
1060        clientSSL.set_connect_state()
1061
1062        serverContext = Context(SSLv23_METHOD)
1063        serverContext.use_certificate(
1064            load_certificate(FILETYPE_PEM, root_cert_pem)
1065        )
1066        serverContext.use_privatekey(
1067            load_privatekey(FILETYPE_PEM, root_key_pem)
1068        )
1069
1070        serverSSL = Connection(serverContext, server)
1071        serverSSL.set_accept_state()
1072
1073        # Without load_verify_locations above, the handshake
1074        # will fail:
1075        # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
1076        #          'certificate verify failed')]
1077        handshake(clientSSL, serverSSL)
1078
1079        cert = clientSSL.get_peer_certificate()
1080        assert cert.get_subject().CN == "Testing Root CA"
1081
1082    def _load_verify_cafile(self, cafile):
1083        """
1084        Verify that if path to a file containing a certificate is passed to
1085        `Context.load_verify_locations` for the ``cafile`` parameter, that
1086        certificate is used as a trust root for the purposes of verifying
1087        connections created using that `Context`.
1088        """
1089        with open(cafile, "w") as fObj:
1090            fObj.write(root_cert_pem.decode("ascii"))
1091
1092        self._load_verify_locations_test(cafile)
1093
1094    def test_load_verify_bytes_cafile(self, tmpfile):
1095        """
1096        `Context.load_verify_locations` accepts a file name as a `bytes`
1097        instance and uses the certificates within for verification purposes.
1098        """
1099        cafile = tmpfile + NON_ASCII.encode(getfilesystemencoding())
1100        self._load_verify_cafile(cafile)
1101
1102    def test_load_verify_unicode_cafile(self, tmpfile):
1103        """
1104        `Context.load_verify_locations` accepts a file name as a `unicode`
1105        instance and uses the certificates within for verification purposes.
1106        """
1107        self._load_verify_cafile(
1108            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
1109        )
1110
1111    def test_load_verify_invalid_file(self, tmpfile):
1112        """
1113        `Context.load_verify_locations` raises `Error` when passed a
1114        non-existent cafile.
1115        """
1116        clientContext = Context(SSLv23_METHOD)
1117        with pytest.raises(Error):
1118            clientContext.load_verify_locations(tmpfile)
1119
1120    def _load_verify_directory_locations_capath(self, capath):
1121        """
1122        Verify that if path to a directory containing certificate files is
1123        passed to ``Context.load_verify_locations`` for the ``capath``
1124        parameter, those certificates are used as trust roots for the purposes
1125        of verifying connections created using that ``Context``.
1126        """
1127        makedirs(capath)
1128        # Hash values computed manually with c_rehash to avoid depending on
1129        # c_rehash in the test suite.  One is from OpenSSL 0.9.8, the other
1130        # from OpenSSL 1.0.0.
1131        for name in [b"c7adac82.0", b"c3705638.0"]:
1132            cafile = join_bytes_or_unicode(capath, name)
1133            with open(cafile, "w") as fObj:
1134                fObj.write(root_cert_pem.decode("ascii"))
1135
1136        self._load_verify_locations_test(None, capath)
1137
1138    def test_load_verify_directory_bytes_capath(self, tmpfile):
1139        """
1140        `Context.load_verify_locations` accepts a directory name as a `bytes`
1141        instance and uses the certificates within for verification purposes.
1142        """
1143        self._load_verify_directory_locations_capath(
1144            tmpfile + NON_ASCII.encode(getfilesystemencoding())
1145        )
1146
1147    def test_load_verify_directory_unicode_capath(self, tmpfile):
1148        """
1149        `Context.load_verify_locations` accepts a directory name as a `unicode`
1150        instance and uses the certificates within for verification purposes.
1151        """
1152        self._load_verify_directory_locations_capath(
1153            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
1154        )
1155
1156    def test_load_verify_locations_wrong_args(self):
1157        """
1158        `Context.load_verify_locations` raises `TypeError` if with non-`str`
1159        arguments.
1160        """
1161        context = Context(SSLv23_METHOD)
1162        with pytest.raises(TypeError):
1163            context.load_verify_locations(object())
1164        with pytest.raises(TypeError):
1165            context.load_verify_locations(object(), object())
1166
1167    @pytest.mark.skipif(
1168        not platform.startswith("linux"),
1169        reason="Loading fallback paths is a linux-specific behavior to "
1170        "accommodate pyca/cryptography manylinux1 wheels",
1171    )
1172    def test_fallback_default_verify_paths(self, monkeypatch):
1173        """
1174        Test that we load certificates successfully on linux from the fallback
1175        path. To do this we set the _CRYPTOGRAPHY_MANYLINUX1_CA_FILE and
1176        _CRYPTOGRAPHY_MANYLINUX1_CA_DIR vars to be equal to whatever the
1177        current OpenSSL default is and we disable
1178        SSL_CTX_SET_default_verify_paths so that it can't find certs unless
1179        it loads via fallback.
1180        """
1181        context = Context(SSLv23_METHOD)
1182        monkeypatch.setattr(
1183            _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1
1184        )
1185        monkeypatch.setattr(
1186            SSL,
1187            "_CRYPTOGRAPHY_MANYLINUX1_CA_FILE",
1188            _ffi.string(_lib.X509_get_default_cert_file()),
1189        )
1190        monkeypatch.setattr(
1191            SSL,
1192            "_CRYPTOGRAPHY_MANYLINUX1_CA_DIR",
1193            _ffi.string(_lib.X509_get_default_cert_dir()),
1194        )
1195        context.set_default_verify_paths()
1196        store = context.get_cert_store()
1197        sk_obj = _lib.X509_STORE_get0_objects(store._store)
1198        assert sk_obj != _ffi.NULL
1199        num = _lib.sk_X509_OBJECT_num(sk_obj)
1200        assert num != 0
1201
1202    def test_check_env_vars(self, monkeypatch):
1203        """
1204        Test that we return True/False appropriately if the env vars are set.
1205        """
1206        context = Context(SSLv23_METHOD)
1207        dir_var = "CUSTOM_DIR_VAR"
1208        file_var = "CUSTOM_FILE_VAR"
1209        assert context._check_env_vars_set(dir_var, file_var) is False
1210        monkeypatch.setenv(dir_var, "value")
1211        monkeypatch.setenv(file_var, "value")
1212        assert context._check_env_vars_set(dir_var, file_var) is True
1213        assert context._check_env_vars_set(dir_var, file_var) is True
1214
1215    def test_verify_no_fallback_if_env_vars_set(self, monkeypatch):
1216        """
1217        Test that we don't use the fallback path if env vars are set.
1218        """
1219        context = Context(SSLv23_METHOD)
1220        monkeypatch.setattr(
1221            _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1
1222        )
1223        dir_env_var = _ffi.string(_lib.X509_get_default_cert_dir_env()).decode(
1224            "ascii"
1225        )
1226        file_env_var = _ffi.string(
1227            _lib.X509_get_default_cert_file_env()
1228        ).decode("ascii")
1229        monkeypatch.setenv(dir_env_var, "value")
1230        monkeypatch.setenv(file_env_var, "value")
1231        context.set_default_verify_paths()
1232
1233        monkeypatch.setattr(
1234            context, "_fallback_default_verify_paths", raiser(SystemError)
1235        )
1236        context.set_default_verify_paths()
1237
1238    @pytest.mark.skipif(
1239        platform == "win32",
1240        reason="set_default_verify_paths appears not to work on Windows.  "
1241        "See LP#404343 and LP#404344.",
1242    )
1243    def test_set_default_verify_paths(self):
1244        """
1245        `Context.set_default_verify_paths` causes the platform-specific CA
1246        certificate locations to be used for verification purposes.
1247        """
1248        # Testing this requires a server with a certificate signed by one
1249        # of the CAs in the platform CA location.  Getting one of those
1250        # costs money.  Fortunately (or unfortunately, depending on your
1251        # perspective), it's easy to think of a public server on the
1252        # internet which has such a certificate.  Connecting to the network
1253        # in a unit test is bad, but it's the only way I can think of to
1254        # really test this. -exarkun
1255        context = Context(SSLv23_METHOD)
1256        context.set_default_verify_paths()
1257        context.set_verify(
1258            VERIFY_PEER,
1259            lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
1260        )
1261
1262        client = socket_any_family()
1263        client.connect(("encrypted.google.com", 443))
1264        clientSSL = Connection(context, client)
1265        clientSSL.set_connect_state()
1266        clientSSL.set_tlsext_host_name(b"encrypted.google.com")
1267        clientSSL.do_handshake()
1268        clientSSL.send(b"GET / HTTP/1.0\r\n\r\n")
1269        assert clientSSL.recv(1024)
1270
1271    def test_fallback_path_is_not_file_or_dir(self):
1272        """
1273        Test that when passed empty arrays or paths that do not exist no
1274        errors are raised.
1275        """
1276        context = Context(SSLv23_METHOD)
1277        context._fallback_default_verify_paths([], [])
1278        context._fallback_default_verify_paths(["/not/a/file"], ["/not/a/dir"])
1279
1280    def test_add_extra_chain_cert_invalid_cert(self):
1281        """
1282        `Context.add_extra_chain_cert` raises `TypeError` if called with an
1283        object which is not an instance of `X509`.
1284        """
1285        context = Context(SSLv23_METHOD)
1286        with pytest.raises(TypeError):
1287            context.add_extra_chain_cert(object())
1288
1289    def _handshake_test(self, serverContext, clientContext):
1290        """
1291        Verify that a client and server created with the given contexts can
1292        successfully handshake and communicate.
1293        """
1294        serverSocket, clientSocket = socket_pair()
1295
1296        server = Connection(serverContext, serverSocket)
1297        server.set_accept_state()
1298
1299        client = Connection(clientContext, clientSocket)
1300        client.set_connect_state()
1301
1302        # Make them talk to each other.
1303        # interact_in_memory(client, server)
1304        for _ in range(3):
1305            for s in [client, server]:
1306                try:
1307                    s.do_handshake()
1308                except WantReadError:
1309                    pass
1310
1311    def test_set_verify_callback_connection_argument(self):
1312        """
1313        The first argument passed to the verify callback is the
1314        `Connection` instance for which verification is taking place.
1315        """
1316        serverContext = Context(SSLv23_METHOD)
1317        serverContext.use_privatekey(
1318            load_privatekey(FILETYPE_PEM, root_key_pem)
1319        )
1320        serverContext.use_certificate(
1321            load_certificate(FILETYPE_PEM, root_cert_pem)
1322        )
1323        serverConnection = Connection(serverContext, None)
1324
1325        class VerifyCallback(object):
1326            def callback(self, connection, *args):
1327                self.connection = connection
1328                return 1
1329
1330        verify = VerifyCallback()
1331        clientContext = Context(SSLv23_METHOD)
1332        clientContext.set_verify(VERIFY_PEER, verify.callback)
1333        clientConnection = Connection(clientContext, None)
1334        clientConnection.set_connect_state()
1335
1336        handshake_in_memory(clientConnection, serverConnection)
1337
1338        assert verify.connection is clientConnection
1339
1340    def test_x509_in_verify_works(self):
1341        """
1342        We had a bug where the X509 cert instantiated in the callback wrapper
1343        didn't __init__ so it was missing objects needed when calling
1344        get_subject. This test sets up a handshake where we call get_subject
1345        on the cert provided to the verify callback.
1346        """
1347        serverContext = Context(SSLv23_METHOD)
1348        serverContext.use_privatekey(
1349            load_privatekey(FILETYPE_PEM, root_key_pem)
1350        )
1351        serverContext.use_certificate(
1352            load_certificate(FILETYPE_PEM, root_cert_pem)
1353        )
1354        serverConnection = Connection(serverContext, None)
1355
1356        def verify_cb_get_subject(conn, cert, errnum, depth, ok):
1357            assert cert.get_subject()
1358            return 1
1359
1360        clientContext = Context(SSLv23_METHOD)
1361        clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject)
1362        clientConnection = Connection(clientContext, None)
1363        clientConnection.set_connect_state()
1364
1365        handshake_in_memory(clientConnection, serverConnection)
1366
1367    def test_set_verify_callback_exception(self):
1368        """
1369        If the verify callback passed to `Context.set_verify` raises an
1370        exception, verification fails and the exception is propagated to the
1371        caller of `Connection.do_handshake`.
1372        """
1373        serverContext = Context(TLSv1_2_METHOD)
1374        serverContext.use_privatekey(
1375            load_privatekey(FILETYPE_PEM, root_key_pem)
1376        )
1377        serverContext.use_certificate(
1378            load_certificate(FILETYPE_PEM, root_cert_pem)
1379        )
1380
1381        clientContext = Context(TLSv1_2_METHOD)
1382
1383        def verify_callback(*args):
1384            raise Exception("silly verify failure")
1385
1386        clientContext.set_verify(VERIFY_PEER, verify_callback)
1387
1388        with pytest.raises(Exception) as exc:
1389            self._handshake_test(serverContext, clientContext)
1390
1391        assert "silly verify failure" == str(exc.value)
1392
1393    def test_set_verify_callback_reference(self):
1394        """
1395        If the verify callback passed to `Context.set_verify` is set multiple
1396        times, the pointers to the old call functions should not be dangling
1397        and trigger a segfault.
1398        """
1399        serverContext = Context(TLSv1_2_METHOD)
1400        serverContext.use_privatekey(
1401            load_privatekey(FILETYPE_PEM, root_key_pem)
1402        )
1403        serverContext.use_certificate(
1404            load_certificate(FILETYPE_PEM, root_cert_pem)
1405        )
1406
1407        clientContext = Context(TLSv1_2_METHOD)
1408
1409        clients = []
1410
1411        for i in range(5):
1412
1413            def verify_callback(*args):
1414                return True
1415
1416            serverSocket, clientSocket = socket_pair()
1417            client = Connection(clientContext, clientSocket)
1418
1419            clients.append((serverSocket, client))
1420
1421            clientContext.set_verify(VERIFY_PEER, verify_callback)
1422
1423        gc.collect()
1424
1425        # Make them talk to each other.
1426        for serverSocket, client in clients:
1427            server = Connection(serverContext, serverSocket)
1428            server.set_accept_state()
1429            client.set_connect_state()
1430
1431            for _ in range(5):
1432                for s in [client, server]:
1433                    try:
1434                        s.do_handshake()
1435                    except WantReadError:
1436                        pass
1437
1438    @pytest.mark.parametrize("mode", [SSL.VERIFY_PEER, SSL.VERIFY_NONE])
1439    def test_set_verify_default_callback(self, mode):
1440        """
1441        If the verify callback is omitted, the preverify value is used.
1442        """
1443        serverContext = Context(TLSv1_2_METHOD)
1444        serverContext.use_privatekey(
1445            load_privatekey(FILETYPE_PEM, root_key_pem)
1446        )
1447        serverContext.use_certificate(
1448            load_certificate(FILETYPE_PEM, root_cert_pem)
1449        )
1450
1451        clientContext = Context(TLSv1_2_METHOD)
1452        clientContext.set_verify(mode, None)
1453
1454        if mode == SSL.VERIFY_PEER:
1455            with pytest.raises(Exception) as exc:
1456                self._handshake_test(serverContext, clientContext)
1457            assert "certificate verify failed" in str(exc.value)
1458        else:
1459            self._handshake_test(serverContext, clientContext)
1460
1461    def test_add_extra_chain_cert(self, tmpdir):
1462        """
1463        `Context.add_extra_chain_cert` accepts an `X509`
1464        instance to add to the certificate chain.
1465
1466        See `_create_certificate_chain` for the details of the
1467        certificate chain tested.
1468
1469        The chain is tested by starting a server with scert and connecting
1470        to it with a client which trusts cacert and requires verification to
1471        succeed.
1472        """
1473        chain = _create_certificate_chain()
1474        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
1475
1476        # Dump the CA certificate to a file because that's the only way to load
1477        # it as a trusted CA in the client context.
1478        for cert, name in [
1479            (cacert, "ca.pem"),
1480            (icert, "i.pem"),
1481            (scert, "s.pem"),
1482        ]:
1483            with tmpdir.join(name).open("w") as f:
1484                f.write(dump_certificate(FILETYPE_PEM, cert).decode("ascii"))
1485
1486        for key, name in [(cakey, "ca.key"), (ikey, "i.key"), (skey, "s.key")]:
1487            with tmpdir.join(name).open("w") as f:
1488                f.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii"))
1489
1490        # Create the server context
1491        serverContext = Context(SSLv23_METHOD)
1492        serverContext.use_privatekey(skey)
1493        serverContext.use_certificate(scert)
1494        # The client already has cacert, we only need to give them icert.
1495        serverContext.add_extra_chain_cert(icert)
1496
1497        # Create the client
1498        clientContext = Context(SSLv23_METHOD)
1499        clientContext.set_verify(
1500            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb
1501        )
1502        clientContext.load_verify_locations(str(tmpdir.join("ca.pem")))
1503
1504        # Try it out.
1505        self._handshake_test(serverContext, clientContext)
1506
1507    def _use_certificate_chain_file_test(self, certdir):
1508        """
1509        Verify that `Context.use_certificate_chain_file` reads a
1510        certificate chain from a specified file.
1511
1512        The chain is tested by starting a server with scert and connecting to
1513        it with a client which trusts cacert and requires verification to
1514        succeed.
1515        """
1516        chain = _create_certificate_chain()
1517        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
1518
1519        makedirs(certdir)
1520
1521        chainFile = join_bytes_or_unicode(certdir, "chain.pem")
1522        caFile = join_bytes_or_unicode(certdir, "ca.pem")
1523
1524        # Write out the chain file.
1525        with open(chainFile, "wb") as fObj:
1526            # Most specific to least general.
1527            fObj.write(dump_certificate(FILETYPE_PEM, scert))
1528            fObj.write(dump_certificate(FILETYPE_PEM, icert))
1529            fObj.write(dump_certificate(FILETYPE_PEM, cacert))
1530
1531        with open(caFile, "w") as fObj:
1532            fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode("ascii"))
1533
1534        serverContext = Context(SSLv23_METHOD)
1535        serverContext.use_certificate_chain_file(chainFile)
1536        serverContext.use_privatekey(skey)
1537
1538        clientContext = Context(SSLv23_METHOD)
1539        clientContext.set_verify(
1540            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb
1541        )
1542        clientContext.load_verify_locations(caFile)
1543
1544        self._handshake_test(serverContext, clientContext)
1545
1546    def test_use_certificate_chain_file_bytes(self, tmpfile):
1547        """
1548        ``Context.use_certificate_chain_file`` accepts the name of a file (as
1549        an instance of ``bytes``) to specify additional certificates to use to
1550        construct and verify a trust chain.
1551        """
1552        self._use_certificate_chain_file_test(
1553            tmpfile + NON_ASCII.encode(getfilesystemencoding())
1554        )
1555
1556    def test_use_certificate_chain_file_unicode(self, tmpfile):
1557        """
1558        ``Context.use_certificate_chain_file`` accepts the name of a file (as
1559        an instance of ``unicode``) to specify additional certificates to use
1560        to construct and verify a trust chain.
1561        """
1562        self._use_certificate_chain_file_test(
1563            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
1564        )
1565
1566    def test_use_certificate_chain_file_wrong_args(self):
1567        """
1568        `Context.use_certificate_chain_file` raises `TypeError` if passed a
1569        non-byte string single argument.
1570        """
1571        context = Context(SSLv23_METHOD)
1572        with pytest.raises(TypeError):
1573            context.use_certificate_chain_file(object())
1574
1575    def test_use_certificate_chain_file_missing_file(self, tmpfile):
1576        """
1577        `Context.use_certificate_chain_file` raises `OpenSSL.SSL.Error` when
1578        passed a bad chain file name (for example, the name of a file which
1579        does not exist).
1580        """
1581        context = Context(SSLv23_METHOD)
1582        with pytest.raises(Error):
1583            context.use_certificate_chain_file(tmpfile)
1584
1585    def test_set_verify_mode(self):
1586        """
1587        `Context.get_verify_mode` returns the verify mode flags previously
1588        passed to `Context.set_verify`.
1589        """
1590        context = Context(SSLv23_METHOD)
1591        assert context.get_verify_mode() == 0
1592        context.set_verify(VERIFY_PEER | VERIFY_CLIENT_ONCE)
1593        assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE)
1594
1595    @pytest.mark.parametrize("mode", [None, 1.0, object(), "mode"])
1596    def test_set_verify_wrong_mode_arg(self, mode):
1597        """
1598        `Context.set_verify` raises `TypeError` if the first argument is
1599        not an integer.
1600        """
1601        context = Context(SSLv23_METHOD)
1602        with pytest.raises(TypeError):
1603            context.set_verify(mode=mode)
1604
1605    @pytest.mark.parametrize("callback", [1.0, "mode", ("foo", "bar")])
1606    def test_set_verify_wrong_callable_arg(self, callback):
1607        """
1608        `Context.set_verify` raises `TypeError` if the second argument
1609        is not callable.
1610        """
1611        context = Context(SSLv23_METHOD)
1612        with pytest.raises(TypeError):
1613            context.set_verify(mode=VERIFY_PEER, callback=callback)
1614
1615    def test_load_tmp_dh_wrong_args(self):
1616        """
1617        `Context.load_tmp_dh` raises `TypeError` if called with a
1618        non-`str` argument.
1619        """
1620        context = Context(SSLv23_METHOD)
1621        with pytest.raises(TypeError):
1622            context.load_tmp_dh(object())
1623
1624    def test_load_tmp_dh_missing_file(self):
1625        """
1626        `Context.load_tmp_dh` raises `OpenSSL.SSL.Error` if the
1627        specified file does not exist.
1628        """
1629        context = Context(SSLv23_METHOD)
1630        with pytest.raises(Error):
1631            context.load_tmp_dh(b"hello")
1632
1633    def _load_tmp_dh_test(self, dhfilename):
1634        """
1635        Verify that calling ``Context.load_tmp_dh`` with the given filename
1636        does not raise an exception.
1637        """
1638        context = Context(SSLv23_METHOD)
1639        with open(dhfilename, "w") as dhfile:
1640            dhfile.write(dhparam)
1641
1642        context.load_tmp_dh(dhfilename)
1643
1644    def test_load_tmp_dh_bytes(self, tmpfile):
1645        """
1646        `Context.load_tmp_dh` loads Diffie-Hellman parameters from the
1647        specified file (given as ``bytes``).
1648        """
1649        self._load_tmp_dh_test(
1650            tmpfile + NON_ASCII.encode(getfilesystemencoding()),
1651        )
1652
1653    def test_load_tmp_dh_unicode(self, tmpfile):
1654        """
1655        `Context.load_tmp_dh` loads Diffie-Hellman parameters from the
1656        specified file (given as ``unicode``).
1657        """
1658        self._load_tmp_dh_test(
1659            tmpfile.decode(getfilesystemencoding()) + NON_ASCII,
1660        )
1661
1662    def test_set_tmp_ecdh(self):
1663        """
1664        `Context.set_tmp_ecdh` sets the elliptic curve for Diffie-Hellman to
1665        the specified curve.
1666        """
1667        context = Context(SSLv23_METHOD)
1668        for curve in get_elliptic_curves():
1669            if curve.name.startswith(u"Oakley-"):
1670                # Setting Oakley-EC2N-4 and Oakley-EC2N-3 adds
1671                # ('bignum routines', 'BN_mod_inverse', 'no inverse') to the
1672                # error queue on OpenSSL 1.0.2.
1673                continue
1674            # The only easily "assertable" thing is that it does not raise an
1675            # exception.
1676            context.set_tmp_ecdh(curve)
1677
1678    def test_set_session_cache_mode_wrong_args(self):
1679        """
1680        `Context.set_session_cache_mode` raises `TypeError` if called with
1681        a non-integer argument.
1682        called with other than one integer argument.
1683        """
1684        context = Context(SSLv23_METHOD)
1685        with pytest.raises(TypeError):
1686            context.set_session_cache_mode(object())
1687
1688    def test_session_cache_mode(self):
1689        """
1690        `Context.set_session_cache_mode` specifies how sessions are cached.
1691        The setting can be retrieved via `Context.get_session_cache_mode`.
1692        """
1693        context = Context(SSLv23_METHOD)
1694        context.set_session_cache_mode(SESS_CACHE_OFF)
1695        off = context.set_session_cache_mode(SESS_CACHE_BOTH)
1696        assert SESS_CACHE_OFF == off
1697        assert SESS_CACHE_BOTH == context.get_session_cache_mode()
1698
1699    def test_get_cert_store(self):
1700        """
1701        `Context.get_cert_store` returns a `X509Store` instance.
1702        """
1703        context = Context(SSLv23_METHOD)
1704        store = context.get_cert_store()
1705        assert isinstance(store, X509Store)
1706
1707    def test_set_tlsext_use_srtp_not_bytes(self):
1708        """
1709        `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material.
1710
1711        It raises a TypeError if the list of profiles is not a byte string.
1712        """
1713        context = Context(SSLv23_METHOD)
1714        with pytest.raises(TypeError):
1715            context.set_tlsext_use_srtp(text_type("SRTP_AES128_CM_SHA1_80"))
1716
1717    def test_set_tlsext_use_srtp_invalid_profile(self):
1718        """
1719        `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material.
1720
1721        It raises an Error if the call to OpenSSL fails.
1722        """
1723        context = Context(SSLv23_METHOD)
1724        with pytest.raises(Error):
1725            context.set_tlsext_use_srtp(b"SRTP_BOGUS")
1726
1727    def test_set_tlsext_use_srtp_valid(self):
1728        """
1729        `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material.
1730
1731        It does not return anything.
1732        """
1733        context = Context(SSLv23_METHOD)
1734        assert context.set_tlsext_use_srtp(b"SRTP_AES128_CM_SHA1_80") is None
1735
1736
class TestServerNameCallback(object):
    """
    Tests for `Context.set_tlsext_servername_callback` and its
    interaction with `Connection`.
    """

    def test_old_callback_forgotten(self):
        """
        Installing a second callback with
        `Context.set_tlsext_servername_callback` releases the reference
        held on the callback it replaces.
        """

        def callback(connection):  # pragma: no cover
            pass

        def replacement(connection):  # pragma: no cover
            pass

        ctx = Context(SSLv23_METHOD)
        ctx.set_tlsext_servername_callback(callback)

        # Keep only a weak reference locally so that any surviving strong
        # reference must belong to something else.
        weak_callback = ref(callback)
        del callback

        ctx.set_tlsext_servername_callback(replacement)

        # CPython happens to free the object after one collection; PyPy
        # needs a second pass.  Either way a freed object demonstrates that
        # the reference was properly dropped.
        for _ in range(2):
            collect()

        survivor = weak_callback()
        if survivor is not None:
            referrers = get_referrers(survivor)
            if len(referrers) > 1:  # pragma: nocover
                pytest.fail("Some references remain: %r" % (referrers,))

    def test_no_servername(self):
        """
        If the client does not send a server name, the callback given to
        `Context.set_tlsext_servername_callback` still runs and
        `Connection.get_servername` returns `None`.
        """
        seen = []

        def record(conn):
            seen.append((conn, conn.get_servername()))

        server_ctx = Context(SSLv23_METHOD)
        server_ctx.set_tlsext_servername_callback(record)

        # Drop the local name; from here on the Context is responsible for
        # keeping the callback alive.
        del record
        collect()

        # A key and certificate are needed to actually accept the connection.
        key = load_privatekey(FILETYPE_PEM, server_key_pem)
        server_ctx.use_privatekey(key)
        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
        server_ctx.use_certificate(cert)

        # Run a small in-memory handshake to trigger the callback.
        server = Connection(server_ctx, None)
        server.set_accept_state()

        client = Connection(Context(SSLv23_METHOD), None)
        client.set_connect_state()

        interact_in_memory(server, client)

        assert seen == [(server, None)]

    def test_servername(self):
        """
        If the client sends a server name in its hello message, the
        callback given to `Context.set_tlsext_servername_callback` runs
        and `Connection.get_servername` returns that name.
        """
        seen = []

        def record(conn):
            seen.append((conn, conn.get_servername()))

        server_ctx = Context(SSLv23_METHOD)
        server_ctx.set_tlsext_servername_callback(record)

        # A key and certificate are needed to actually accept the connection.
        key = load_privatekey(FILETYPE_PEM, server_key_pem)
        server_ctx.use_privatekey(key)
        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
        server_ctx.use_certificate(cert)

        # Run a small in-memory handshake to trigger the callback.
        server = Connection(server_ctx, None)
        server.set_accept_state()

        client = Connection(Context(SSLv23_METHOD), None)
        client.set_connect_state()
        client.set_tlsext_host_name(b"foo1.example.com")

        interact_in_memory(server, client)

        assert seen == [(server, b"foo1.example.com")]
1844
1845
class TestApplicationLayerProtoNegotiation(object):
    """
    Tests for ALPN in PyOpenSSL.

    Every test drives an in-memory handshake between a client offering
    ``http/1.1`` and ``spdy/2`` and a server whose behaviour varies per
    test.  The shared setup lives in the two private helpers below.
    """

    def _make_server(self, select_callback=None):
        """
        Return a memory-BIO server `Connection` in accept state.

        :param select_callback: If not `None`, installed on the server
            context with `Context.set_alpn_select_callback`.
        """
        server_context = Context(SSLv23_METHOD)
        if select_callback is not None:
            server_context.set_alpn_select_callback(select_callback)

        # Necessary to actually accept the connection
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem)
        )
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        server = Connection(server_context, None)
        server.set_accept_state()
        return server

    def _make_client(self, protos_on_connection=False):
        """
        Return a memory-BIO client `Connection` in connect state that
        offers ``http/1.1`` and ``spdy/2``.

        :param protos_on_connection: If true, the protocols are set on the
            `Connection`; otherwise they are set on its `Context`.
        """
        protos = [b"http/1.1", b"spdy/2"]
        client_context = Context(SSLv23_METHOD)
        if not protos_on_connection:
            client_context.set_alpn_protos(protos)

        client = Connection(client_context, None)
        if protos_on_connection:
            client.set_alpn_protos(protos)
        client.set_connect_state()
        return client

    def test_alpn_success(self):
        """
        Clients and servers that agree on the negotiated ALPN protocol can
        correctly establish a connection, and the agreed protocol is
        reported by the connections.
        """
        select_args = []

        def select(conn, options):
            select_args.append((conn, options))
            return b"spdy/2"

        server = self._make_server(select)
        client = self._make_client()

        interact_in_memory(server, client)

        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

        assert server.get_alpn_proto_negotiated() == b"spdy/2"
        assert client.get_alpn_proto_negotiated() == b"spdy/2"

    def test_alpn_set_on_connection(self):
        """
        The same as test_alpn_success, but setting the ALPN protocols on
        the connection rather than the context.
        """
        select_args = []

        def select(conn, options):
            select_args.append((conn, options))
            return b"spdy/2"

        server = self._make_server(select)
        # The client context carries no ALPN protocols; they are set on the
        # client connection instead.
        client = self._make_client(protos_on_connection=True)

        interact_in_memory(server, client)

        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

        assert server.get_alpn_proto_negotiated() == b"spdy/2"
        assert client.get_alpn_proto_negotiated() == b"spdy/2"

    def test_alpn_server_fail(self):
        """
        When clients and servers cannot agree on what protocol to use next
        the TLS connection does not get established.
        """
        select_args = []

        def select(conn, options):
            select_args.append((conn, options))
            return b""

        server = self._make_server(select)
        client = self._make_client()

        # The server's select callback returns an empty protocol, so the
        # handshake is expected to fail.
        with pytest.raises(Error):
            interact_in_memory(server, client)

        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

    def test_alpn_no_server_overlap(self):
        """
        A server can allow a TLS handshake to complete without
        agreeing to an application protocol by returning
        ``NO_OVERLAPPING_PROTOCOLS``.
        """
        refusal_args = []

        def refusal(conn, options):
            refusal_args.append((conn, options))
            return NO_OVERLAPPING_PROTOCOLS

        server = self._make_server(refusal)
        client = self._make_client()

        # Do the dance.
        interact_in_memory(server, client)

        assert refusal_args == [(server, [b"http/1.1", b"spdy/2"])]

        assert client.get_alpn_proto_negotiated() == b""

    def test_alpn_select_cb_returns_invalid_value(self):
        """
        If the ALPN selection callback returns anything other than
        a bytestring or ``NO_OVERLAPPING_PROTOCOLS``, a
        :py:exc:`TypeError` is raised.
        """
        invalid_cb_args = []

        def invalid_cb(conn, options):
            invalid_cb_args.append((conn, options))
            return u"can't return unicode"

        server = self._make_server(invalid_cb)
        client = self._make_client()

        # Do the dance.
        with pytest.raises(TypeError):
            interact_in_memory(server, client)

        assert invalid_cb_args == [(server, [b"http/1.1", b"spdy/2"])]

        assert client.get_alpn_proto_negotiated() == b""

    def test_alpn_no_server(self):
        """
        When clients and servers cannot agree on what protocol to use next
        because the server doesn't offer ALPN, no protocol is negotiated.
        """
        # No select callback: the server does not take part in ALPN.
        server = self._make_server()
        client = self._make_client()

        # Do the dance.
        interact_in_memory(server, client)

        assert client.get_alpn_proto_negotiated() == b""

    def test_alpn_callback_exception(self):
        """
        We can handle exceptions in the ALPN select callback.
        """
        select_args = []

        def select(conn, options):
            select_args.append((conn, options))
            raise TypeError()

        server = self._make_server(select)
        client = self._make_client()

        with pytest.raises(TypeError):
            interact_in_memory(server, client)
        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]
2115
2116
class TestSession(object):
    """
    Unit tests for :py:obj:`OpenSSL.SSL.Session`.
    """

    def test_construction(self):
        """
        Calling :py:class:`Session` without arguments yields a new
        instance of that type.
        """
        assert isinstance(Session(), Session)
2129
2130
2131class TestConnection(object):
2132    """
2133    Unit tests for `OpenSSL.SSL.Connection`.
2134    """
2135
2136    # XXX get_peer_certificate -> None
2137    # XXX sock_shutdown
2138    # XXX master_key -> TypeError
2139    # XXX server_random -> TypeError
2140    # XXX connect -> TypeError
2141    # XXX connect_ex -> TypeError
2142    # XXX set_connect_state -> TypeError
2143    # XXX set_accept_state -> TypeError
2144    # XXX do_handshake -> TypeError
2145    # XXX bio_read -> TypeError
2146    # XXX recv -> TypeError
2147    # XXX send -> TypeError
2148    # XXX bio_write -> TypeError
2149
2150    def test_type(self):
2151        """
2152        `Connection` can be used to create instances of that type.
2153        """
2154        ctx = Context(SSLv23_METHOD)
2155        assert is_consistent_type(Connection, "Connection", ctx, None)
2156
2157    @pytest.mark.parametrize("bad_context", [object(), "context", None, 1])
2158    def test_wrong_args(self, bad_context):
2159        """
2160        `Connection.__init__` raises `TypeError` if called with a non-`Context`
2161        instance argument.
2162        """
2163        with pytest.raises(TypeError):
2164            Connection(bad_context)
2165
2166    @pytest.mark.parametrize("bad_bio", [object(), None, 1, [1, 2, 3]])
2167    def test_bio_write_wrong_args(self, bad_bio):
2168        """
2169        `Connection.bio_write` raises `TypeError` if called with a non-bytes
2170        (or text) argument.
2171        """
2172        context = Context(SSLv23_METHOD)
2173        connection = Connection(context, None)
2174        with pytest.raises(TypeError):
2175            connection.bio_write(bad_bio)
2176
2177    def test_bio_write(self):
2178        """
2179        `Connection.bio_write` does not raise if called with bytes or
2180        bytearray, warns if called with text.
2181        """
2182        context = Context(SSLv23_METHOD)
2183        connection = Connection(context, None)
2184        connection.bio_write(b"xy")
2185        connection.bio_write(bytearray(b"za"))
2186        with pytest.warns(DeprecationWarning):
2187            connection.bio_write(u"deprecated")
2188
2189    def test_get_context(self):
2190        """
2191        `Connection.get_context` returns the `Context` instance used to
2192        construct the `Connection` instance.
2193        """
2194        context = Context(SSLv23_METHOD)
2195        connection = Connection(context, None)
2196        assert connection.get_context() is context
2197
2198    def test_set_context_wrong_args(self):
2199        """
2200        `Connection.set_context` raises `TypeError` if called with a
2201        non-`Context` instance argument.
2202        """
2203        ctx = Context(SSLv23_METHOD)
2204        connection = Connection(ctx, None)
2205        with pytest.raises(TypeError):
2206            connection.set_context(object())
2207        with pytest.raises(TypeError):
2208            connection.set_context("hello")
2209        with pytest.raises(TypeError):
2210            connection.set_context(1)
2211        assert ctx is connection.get_context()
2212
2213    def test_set_context(self):
2214        """
2215        `Connection.set_context` specifies a new `Context` instance to be
2216        used for the connection.
2217        """
2218        original = Context(SSLv23_METHOD)
2219        replacement = Context(SSLv23_METHOD)
2220        connection = Connection(original, None)
2221        connection.set_context(replacement)
2222        assert replacement is connection.get_context()
2223        # Lose our references to the contexts, just in case the Connection
2224        # isn't properly managing its own contributions to their reference
2225        # counts.
2226        del original, replacement
2227        collect()
2228
2229    def test_set_tlsext_host_name_wrong_args(self):
2230        """
2231        If `Connection.set_tlsext_host_name` is called with a non-byte string
2232        argument or a byte string with an embedded NUL, `TypeError` is raised.
2233        """
2234        conn = Connection(Context(SSLv23_METHOD), None)
2235        with pytest.raises(TypeError):
2236            conn.set_tlsext_host_name(object())
2237        with pytest.raises(TypeError):
2238            conn.set_tlsext_host_name(b"with\0null")
2239
2240        if not PY2:
2241            # On Python 3.x, don't accidentally implicitly convert from text.
2242            with pytest.raises(TypeError):
2243                conn.set_tlsext_host_name(b"example.com".decode("ascii"))
2244
2245    def test_pending(self):
2246        """
2247        `Connection.pending` returns the number of bytes available for
2248        immediate read.
2249        """
2250        connection = Connection(Context(SSLv23_METHOD), None)
2251        assert connection.pending() == 0
2252
2253    def test_peek(self):
2254        """
2255        `Connection.recv` peeks into the connection if `socket.MSG_PEEK` is
2256        passed.
2257        """
2258        server, client = loopback()
2259        server.send(b"xy")
2260        assert client.recv(2, MSG_PEEK) == b"xy"
2261        assert client.recv(2, MSG_PEEK) == b"xy"
2262        assert client.recv(2) == b"xy"
2263
2264    def test_connect_wrong_args(self):
2265        """
2266        `Connection.connect` raises `TypeError` if called with a non-address
2267        argument.
2268        """
2269        connection = Connection(Context(SSLv23_METHOD), socket_any_family())
2270        with pytest.raises(TypeError):
2271            connection.connect(None)
2272
2273    def test_connect_refused(self):
2274        """
2275        `Connection.connect` raises `socket.error` if the underlying socket
2276        connect method raises it.
2277        """
2278        client = socket_any_family()
2279        context = Context(SSLv23_METHOD)
2280        clientSSL = Connection(context, client)
2281        # pytest.raises here doesn't work because of a bug in py.test on Python
2282        # 2.6: https://github.com/pytest-dev/pytest/issues/988
2283        try:
2284            clientSSL.connect((loopback_address(client), 1))
2285        except error as e:
2286            exc = e
2287        assert exc.args[0] == ECONNREFUSED
2288
2289    def test_connect(self):
2290        """
2291        `Connection.connect` establishes a connection to the specified address.
2292        """
2293        port = socket_any_family()
2294        port.bind(("", 0))
2295        port.listen(3)
2296
2297        clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family))
2298        clientSSL.connect((loopback_address(port), port.getsockname()[1]))
2299        # XXX An assertion?  Or something?
2300
2301    @pytest.mark.skipif(
2302        platform == "darwin",
2303        reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4",
2304    )
2305    def test_connect_ex(self):
2306        """
2307        If there is a connection error, `Connection.connect_ex` returns the
2308        errno instead of raising an exception.
2309        """
2310        port = socket_any_family()
2311        port.bind(("", 0))
2312        port.listen(3)
2313
2314        clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family))
2315        clientSSL.setblocking(False)
2316        result = clientSSL.connect_ex(port.getsockname())
2317        expected = (EINPROGRESS, EWOULDBLOCK)
2318        assert result in expected
2319
2320    def test_accept(self):
2321        """
2322        `Connection.accept` accepts a pending connection attempt and returns a
2323        tuple of a new `Connection` (the accepted client) and the address the
2324        connection originated from.
2325        """
2326        ctx = Context(SSLv23_METHOD)
2327        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
2328        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
2329        port = socket_any_family()
2330        portSSL = Connection(ctx, port)
2331        portSSL.bind(("", 0))
2332        portSSL.listen(3)
2333
2334        clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family))
2335
2336        # Calling portSSL.getsockname() here to get the server IP address
2337        # sounds great, but frequently fails on Windows.
2338        clientSSL.connect((loopback_address(port), portSSL.getsockname()[1]))
2339
2340        serverSSL, address = portSSL.accept()
2341
2342        assert isinstance(serverSSL, Connection)
2343        assert serverSSL.get_context() is ctx
2344        assert address == clientSSL.getsockname()
2345
2346    def test_shutdown_wrong_args(self):
2347        """
2348        `Connection.set_shutdown` raises `TypeError` if called with arguments
2349        other than integers.
2350        """
2351        connection = Connection(Context(SSLv23_METHOD), None)
2352        with pytest.raises(TypeError):
2353            connection.set_shutdown(None)
2354
2355    def test_shutdown(self):
2356        """
2357        `Connection.shutdown` performs an SSL-level connection shutdown.
2358        """
2359        server, client = loopback()
2360        assert not server.shutdown()
2361        assert server.get_shutdown() == SENT_SHUTDOWN
2362        with pytest.raises(ZeroReturnError):
2363            client.recv(1024)
2364        assert client.get_shutdown() == RECEIVED_SHUTDOWN
2365        client.shutdown()
2366        assert client.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN)
2367        with pytest.raises(ZeroReturnError):
2368            server.recv(1024)
2369        assert server.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN)
2370
2371    def test_shutdown_closed(self):
2372        """
2373        If the underlying socket is closed, `Connection.shutdown` propagates
2374        the write error from the low level write call.
2375        """
2376        server, client = loopback()
2377        server.sock_shutdown(2)
2378        with pytest.raises(SysCallError) as exc:
2379            server.shutdown()
2380        if platform == "win32":
2381            assert exc.value.args[0] == ESHUTDOWN
2382        else:
2383            assert exc.value.args[0] == EPIPE
2384
2385    def test_shutdown_truncated(self):
2386        """
2387        If the underlying connection is truncated, `Connection.shutdown`
2388        raises an `Error`.
2389        """
2390        server_ctx = Context(SSLv23_METHOD)
2391        client_ctx = Context(SSLv23_METHOD)
2392        server_ctx.use_privatekey(
2393            load_privatekey(FILETYPE_PEM, server_key_pem)
2394        )
2395        server_ctx.use_certificate(
2396            load_certificate(FILETYPE_PEM, server_cert_pem)
2397        )
2398        server = Connection(server_ctx, None)
2399        client = Connection(client_ctx, None)
2400        handshake_in_memory(client, server)
2401        assert not server.shutdown()
2402        with pytest.raises(WantReadError):
2403            server.shutdown()
2404        server.bio_shutdown()
2405        with pytest.raises(Error):
2406            server.shutdown()
2407
2408    def test_set_shutdown(self):
2409        """
2410        `Connection.set_shutdown` sets the state of the SSL connection
2411        shutdown process.
2412        """
2413        connection = Connection(Context(SSLv23_METHOD), socket_any_family())
2414        connection.set_shutdown(RECEIVED_SHUTDOWN)
2415        assert connection.get_shutdown() == RECEIVED_SHUTDOWN
2416
2417    def test_state_string(self):
2418        """
2419        `Connection.state_string` verbosely describes the current state of
2420        the `Connection`.
2421        """
2422        server, client = socket_pair()
2423        server = loopback_server_factory(server)
2424        client = loopback_client_factory(client)
2425
2426        assert server.get_state_string() in [
2427            b"before/accept initialization",
2428            b"before SSL initialization",
2429        ]
2430        assert client.get_state_string() in [
2431            b"before/connect initialization",
2432            b"before SSL initialization",
2433        ]
2434
2435    def test_app_data(self):
2436        """
2437        Any object can be set as app data by passing it to
2438        `Connection.set_app_data` and later retrieved with
2439        `Connection.get_app_data`.
2440        """
2441        conn = Connection(Context(SSLv23_METHOD), None)
2442        assert None is conn.get_app_data()
2443        app_data = object()
2444        conn.set_app_data(app_data)
2445        assert conn.get_app_data() is app_data
2446
2447    def test_makefile(self):
2448        """
2449        `Connection.makefile` is not implemented and calling that
2450        method raises `NotImplementedError`.
2451        """
2452        conn = Connection(Context(SSLv23_METHOD), None)
2453        with pytest.raises(NotImplementedError):
2454            conn.makefile()
2455
2456    def test_get_certificate(self):
2457        """
2458        `Connection.get_certificate` returns the local certificate.
2459        """
2460        chain = _create_certificate_chain()
2461        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
2462
2463        context = Context(SSLv23_METHOD)
2464        context.use_certificate(scert)
2465        client = Connection(context, None)
2466        cert = client.get_certificate()
2467        assert cert is not None
2468        assert "Server Certificate" == cert.get_subject().CN
2469
2470    def test_get_certificate_none(self):
2471        """
2472        `Connection.get_certificate` returns the local certificate.
2473
2474        If there is no certificate, it returns None.
2475        """
2476        context = Context(SSLv23_METHOD)
2477        client = Connection(context, None)
2478        cert = client.get_certificate()
2479        assert cert is None
2480
2481    def test_get_peer_cert_chain(self):
2482        """
2483        `Connection.get_peer_cert_chain` returns a list of certificates
2484        which the connected server returned for the certification verification.
2485        """
2486        chain = _create_certificate_chain()
2487        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
2488
2489        serverContext = Context(SSLv23_METHOD)
2490        serverContext.use_privatekey(skey)
2491        serverContext.use_certificate(scert)
2492        serverContext.add_extra_chain_cert(icert)
2493        serverContext.add_extra_chain_cert(cacert)
2494        server = Connection(serverContext, None)
2495        server.set_accept_state()
2496
2497        # Create the client
2498        clientContext = Context(SSLv23_METHOD)
2499        clientContext.set_verify(VERIFY_NONE, verify_cb)
2500        client = Connection(clientContext, None)
2501        client.set_connect_state()
2502
2503        interact_in_memory(client, server)
2504
2505        chain = client.get_peer_cert_chain()
2506        assert len(chain) == 3
2507        assert "Server Certificate" == chain[0].get_subject().CN
2508        assert "Intermediate Certificate" == chain[1].get_subject().CN
2509        assert "Authority Certificate" == chain[2].get_subject().CN
2510
2511    def test_get_peer_cert_chain_none(self):
2512        """
2513        `Connection.get_peer_cert_chain` returns `None` if the peer sends
2514        no certificate chain.
2515        """
2516        ctx = Context(SSLv23_METHOD)
2517        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
2518        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
2519        server = Connection(ctx, None)
2520        server.set_accept_state()
2521        client = Connection(Context(SSLv23_METHOD), None)
2522        client.set_connect_state()
2523        interact_in_memory(client, server)
2524        assert None is server.get_peer_cert_chain()
2525
2526    def test_get_verified_chain(self):
2527        """
2528        `Connection.get_verified_chain` returns a list of certificates
2529        which the connected server returned for the certification verification.
2530        """
2531        chain = _create_certificate_chain()
2532        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
2533
2534        serverContext = Context(SSLv23_METHOD)
2535        serverContext.use_privatekey(skey)
2536        serverContext.use_certificate(scert)
2537        serverContext.add_extra_chain_cert(icert)
2538        serverContext.add_extra_chain_cert(cacert)
2539        server = Connection(serverContext, None)
2540        server.set_accept_state()
2541
2542        # Create the client
2543        clientContext = Context(SSLv23_METHOD)
2544        # cacert is self-signed so the client must trust it for verification
2545        # to succeed.
2546        clientContext.get_cert_store().add_cert(cacert)
2547        clientContext.set_verify(VERIFY_PEER, verify_cb)
2548        client = Connection(clientContext, None)
2549        client.set_connect_state()
2550
2551        interact_in_memory(client, server)
2552
2553        chain = client.get_verified_chain()
2554        assert len(chain) == 3
2555        assert "Server Certificate" == chain[0].get_subject().CN
2556        assert "Intermediate Certificate" == chain[1].get_subject().CN
2557        assert "Authority Certificate" == chain[2].get_subject().CN
2558
2559    def test_get_verified_chain_none(self):
2560        """
2561        `Connection.get_verified_chain` returns `None` if the peer sends
2562        no certificate chain.
2563        """
2564        ctx = Context(SSLv23_METHOD)
2565        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
2566        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
2567        server = Connection(ctx, None)
2568        server.set_accept_state()
2569        client = Connection(Context(SSLv23_METHOD), None)
2570        client.set_connect_state()
2571        interact_in_memory(client, server)
2572        assert None is server.get_verified_chain()
2573
2574    def test_get_verified_chain_unconnected(self):
2575        """
2576        `Connection.get_verified_chain` returns `None` when used with an object
2577        which has not been connected.
2578        """
2579        ctx = Context(SSLv23_METHOD)
2580        server = Connection(ctx, None)
2581        assert None is server.get_verified_chain()
2582
2583    def test_get_session_unconnected(self):
2584        """
2585        `Connection.get_session` returns `None` when used with an object
2586        which has not been connected.
2587        """
2588        ctx = Context(SSLv23_METHOD)
2589        server = Connection(ctx, None)
2590        session = server.get_session()
2591        assert None is session
2592
2593    def test_server_get_session(self):
2594        """
2595        On the server side of a connection, `Connection.get_session` returns a
2596        `Session` instance representing the SSL session for that connection.
2597        """
2598        server, client = loopback()
2599        session = server.get_session()
2600        assert isinstance(session, Session)
2601
2602    def test_client_get_session(self):
2603        """
2604        On the client side of a connection, `Connection.get_session`
2605        returns a `Session` instance representing the SSL session for
2606        that connection.
2607        """
2608        server, client = loopback()
2609        session = client.get_session()
2610        assert isinstance(session, Session)
2611
2612    def test_set_session_wrong_args(self):
2613        """
2614        `Connection.set_session` raises `TypeError` if called with an object
2615        that is not an instance of `Session`.
2616        """
2617        ctx = Context(SSLv23_METHOD)
2618        connection = Connection(ctx, None)
2619        with pytest.raises(TypeError):
2620            connection.set_session(123)
2621        with pytest.raises(TypeError):
2622            connection.set_session("hello")
2623        with pytest.raises(TypeError):
2624            connection.set_session(object())
2625
2626    def test_client_set_session(self):
2627        """
2628        `Connection.set_session`, when used prior to a connection being
2629        established, accepts a `Session` instance and causes an attempt to
2630        re-use the session it represents when the SSL handshake is performed.
2631        """
2632        key = load_privatekey(FILETYPE_PEM, server_key_pem)
2633        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
2634        ctx = Context(TLSv1_2_METHOD)
2635        ctx.use_privatekey(key)
2636        ctx.use_certificate(cert)
2637        ctx.set_session_id("unity-test")
2638
2639        def makeServer(socket):
2640            server = Connection(ctx, socket)
2641            server.set_accept_state()
2642            return server
2643
2644        originalServer, originalClient = loopback(server_factory=makeServer)
2645        originalSession = originalClient.get_session()
2646
2647        def makeClient(socket):
2648            client = loopback_client_factory(socket)
2649            client.set_session(originalSession)
2650            return client
2651
2652        resumedServer, resumedClient = loopback(
2653            server_factory=makeServer, client_factory=makeClient
2654        )
2655
2656        # This is a proxy: in general, we have no access to any unique
2657        # identifier for the session (new enough versions of OpenSSL expose
2658        # a hash which could be usable, but "new enough" is very, very new).
2659        # Instead, exploit the fact that the master key is re-used if the
2660        # session is re-used.  As long as the master key for the two
2661        # connections is the same, the session was re-used!
2662        assert originalServer.master_key() == resumedServer.master_key()
2663
2664    def test_set_session_wrong_method(self):
2665        """
2666        If `Connection.set_session` is passed a `Session` instance associated
2667        with a context using a different SSL method than the `Connection`
2668        is using, a `OpenSSL.SSL.Error` is raised.
2669        """
2670        v1 = TLSv1_2_METHOD
2671        v2 = TLSv1_METHOD
2672
2673        key = load_privatekey(FILETYPE_PEM, server_key_pem)
2674        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
2675        ctx = Context(v1)
2676        ctx.use_privatekey(key)
2677        ctx.use_certificate(cert)
2678        ctx.set_session_id(b"unity-test")
2679
2680        def makeServer(socket):
2681            server = Connection(ctx, socket)
2682            server.set_accept_state()
2683            return server
2684
2685        def makeOriginalClient(socket):
2686            client = Connection(Context(v1), socket)
2687            client.set_connect_state()
2688            return client
2689
2690        originalServer, originalClient = loopback(
2691            server_factory=makeServer, client_factory=makeOriginalClient
2692        )
2693        originalSession = originalClient.get_session()
2694
2695        def makeClient(socket):
2696            # Intentionally use a different, incompatible method here.
2697            client = Connection(Context(v2), socket)
2698            client.set_connect_state()
2699            client.set_session(originalSession)
2700            return client
2701
2702        with pytest.raises(Error):
2703            loopback(client_factory=makeClient, server_factory=makeServer)
2704
2705    def test_wantWriteError(self):
2706        """
2707        `Connection` methods which generate output raise
2708        `OpenSSL.SSL.WantWriteError` if writing to the connection's BIO
2709        fail indicating a should-write state.
2710        """
2711        client_socket, server_socket = socket_pair()
2712        # Fill up the client's send buffer so Connection won't be able to write
2713        # anything.  Only write a single byte at a time so we can be sure we
2714        # completely fill the buffer.  Even though the socket API is allowed to
2715        # signal a short write via its return value it seems this doesn't
2716        # always happen on all platforms (FreeBSD and OS X particular) for the
2717        # very last bit of available buffer space.
2718        msg = b"x"
2719        for i in range(1024 * 1024 * 64):
2720            try:
2721                client_socket.send(msg)
2722            except error as e:
2723                if e.errno == EWOULDBLOCK:
2724                    break
2725                raise
2726        else:
2727            pytest.fail(
2728                "Failed to fill socket buffer, cannot test BIO want write"
2729            )
2730
2731        ctx = Context(SSLv23_METHOD)
2732        conn = Connection(ctx, client_socket)
2733        # Client's speak first, so make it an SSL client
2734        conn.set_connect_state()
2735        with pytest.raises(WantWriteError):
2736            conn.do_handshake()
2737
2738    # XXX want_read
2739
2740    def test_get_finished_before_connect(self):
2741        """
2742        `Connection.get_finished` returns `None` before TLS handshake
2743        is completed.
2744        """
2745        ctx = Context(SSLv23_METHOD)
2746        connection = Connection(ctx, None)
2747        assert connection.get_finished() is None
2748
2749    def test_get_peer_finished_before_connect(self):
2750        """
2751        `Connection.get_peer_finished` returns `None` before TLS handshake
2752        is completed.
2753        """
2754        ctx = Context(SSLv23_METHOD)
2755        connection = Connection(ctx, None)
2756        assert connection.get_peer_finished() is None
2757
2758    def test_get_finished(self):
2759        """
2760        `Connection.get_finished` method returns the TLS Finished message send
2761        from client, or server. Finished messages are send during
2762        TLS handshake.
2763        """
2764        server, client = loopback()
2765
2766        assert server.get_finished() is not None
2767        assert len(server.get_finished()) > 0
2768
2769    def test_get_peer_finished(self):
2770        """
2771        `Connection.get_peer_finished` method returns the TLS Finished
2772        message received from client, or server. Finished messages are send
2773        during TLS handshake.
2774        """
2775        server, client = loopback()
2776
2777        assert server.get_peer_finished() is not None
2778        assert len(server.get_peer_finished()) > 0
2779
2780    def test_tls_finished_message_symmetry(self):
2781        """
2782        The TLS Finished message send by server must be the TLS Finished
2783        message received by client.
2784
2785        The TLS Finished message send by client must be the TLS Finished
2786        message received by server.
2787        """
2788        server, client = loopback()
2789
2790        assert server.get_finished() == client.get_peer_finished()
2791        assert client.get_finished() == server.get_peer_finished()
2792
2793    def test_get_cipher_name_before_connect(self):
2794        """
2795        `Connection.get_cipher_name` returns `None` if no connection
2796        has been established.
2797        """
2798        ctx = Context(SSLv23_METHOD)
2799        conn = Connection(ctx, None)
2800        assert conn.get_cipher_name() is None
2801
2802    def test_get_cipher_name(self):
2803        """
2804        `Connection.get_cipher_name` returns a `unicode` string giving the
2805        name of the currently used cipher.
2806        """
2807        server, client = loopback()
2808        server_cipher_name, client_cipher_name = (
2809            server.get_cipher_name(),
2810            client.get_cipher_name(),
2811        )
2812
2813        assert isinstance(server_cipher_name, text_type)
2814        assert isinstance(client_cipher_name, text_type)
2815
2816        assert server_cipher_name == client_cipher_name
2817
2818    def test_get_cipher_version_before_connect(self):
2819        """
2820        `Connection.get_cipher_version` returns `None` if no connection
2821        has been established.
2822        """
2823        ctx = Context(SSLv23_METHOD)
2824        conn = Connection(ctx, None)
2825        assert conn.get_cipher_version() is None
2826
2827    def test_get_cipher_version(self):
2828        """
2829        `Connection.get_cipher_version` returns a `unicode` string giving
2830        the protocol name of the currently used cipher.
2831        """
2832        server, client = loopback()
2833        server_cipher_version, client_cipher_version = (
2834            server.get_cipher_version(),
2835            client.get_cipher_version(),
2836        )
2837
2838        assert isinstance(server_cipher_version, text_type)
2839        assert isinstance(client_cipher_version, text_type)
2840
2841        assert server_cipher_version == client_cipher_version
2842
2843    def test_get_cipher_bits_before_connect(self):
2844        """
2845        `Connection.get_cipher_bits` returns `None` if no connection has
2846        been established.
2847        """
2848        ctx = Context(SSLv23_METHOD)
2849        conn = Connection(ctx, None)
2850        assert conn.get_cipher_bits() is None
2851
2852    def test_get_cipher_bits(self):
2853        """
2854        `Connection.get_cipher_bits` returns the number of secret bits
2855        of the currently used cipher.
2856        """
2857        server, client = loopback()
2858        server_cipher_bits, client_cipher_bits = (
2859            server.get_cipher_bits(),
2860            client.get_cipher_bits(),
2861        )
2862
2863        assert isinstance(server_cipher_bits, int)
2864        assert isinstance(client_cipher_bits, int)
2865
2866        assert server_cipher_bits == client_cipher_bits
2867
2868    def test_get_protocol_version_name(self):
2869        """
2870        `Connection.get_protocol_version_name()` returns a string giving the
2871        protocol version of the current connection.
2872        """
2873        server, client = loopback()
2874        client_protocol_version_name = client.get_protocol_version_name()
2875        server_protocol_version_name = server.get_protocol_version_name()
2876
2877        assert isinstance(server_protocol_version_name, text_type)
2878        assert isinstance(client_protocol_version_name, text_type)
2879
2880        assert server_protocol_version_name == client_protocol_version_name
2881
2882    def test_get_protocol_version(self):
2883        """
2884        `Connection.get_protocol_version()` returns an integer
2885        giving the protocol version of the current connection.
2886        """
2887        server, client = loopback()
2888        client_protocol_version = client.get_protocol_version()
2889        server_protocol_version = server.get_protocol_version()
2890
2891        assert isinstance(server_protocol_version, int)
2892        assert isinstance(client_protocol_version, int)
2893
2894        assert server_protocol_version == client_protocol_version
2895
2896    def test_wantReadError(self):
2897        """
2898        `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are
2899        no bytes available to be read from the BIO.
2900        """
2901        ctx = Context(SSLv23_METHOD)
2902        conn = Connection(ctx, None)
2903        with pytest.raises(WantReadError):
2904            conn.bio_read(1024)
2905
2906    @pytest.mark.parametrize("bufsize", [1.0, None, object(), "bufsize"])
2907    def test_bio_read_wrong_args(self, bufsize):
2908        """
2909        `Connection.bio_read` raises `TypeError` if passed a non-integer
2910        argument.
2911        """
2912        ctx = Context(SSLv23_METHOD)
2913        conn = Connection(ctx, None)
2914        with pytest.raises(TypeError):
2915            conn.bio_read(bufsize)
2916
2917    def test_buffer_size(self):
2918        """
2919        `Connection.bio_read` accepts an integer giving the maximum number
2920        of bytes to read and return.
2921        """
2922        ctx = Context(SSLv23_METHOD)
2923        conn = Connection(ctx, None)
2924        conn.set_connect_state()
2925        try:
2926            conn.do_handshake()
2927        except WantReadError:
2928            pass
2929        data = conn.bio_read(2)
2930        assert 2 == len(data)
2931
2932
class TestConnectionGetCipherList(object):
    """
    Tests for `Connection.get_cipher_list`.
    """

    def test_result(self):
        """
        `Connection.get_cipher_list` returns a `list` whose items are all
        `str` cipher names.
        """
        cipher_names = Connection(
            Context(SSLv23_METHOD), None
        ).get_cipher_list()
        assert isinstance(cipher_names, list)
        assert all(isinstance(name, str) for name in cipher_names)
2948
2949
class VeryLarge(bytes):
    """
    Stand-in for a huge byte string: reports a length of 2**31 without
    actually allocating that much memory.
    """

    def __len__(self):
        return 1 << 31
2957
2958
class TestConnectionSend(object):
    """
    Tests for `Connection.send`.
    """

    def test_wrong_args(self):
        """
        `Connection.send` raises `TypeError` when its first argument is not
        a string-like object.
        """
        connection = Connection(Context(SSLv23_METHOD), None)
        for bad_buf in (object(), [1, 2, 3]):
            with pytest.raises(TypeError):
                connection.send(bad_buf)

    def test_short_bytes(self):
        """
        A short byte string handed to `Connection.send` is transmitted in
        full, and the number of bytes sent is returned.
        """
        server, client = loopback()
        assert server.send(b"xy") == 2
        assert client.recv(2) == b"xy"

    def test_text(self):
        """
        Text handed to `Connection.send` is transmitted in full and the
        number of bytes sent is returned, but a `DeprecationWarning` is
        emitted.
        """
        server, client = loopback()
        expected_message = (
            "{0} for buf is no longer accepted, use bytes".format(
                WARNING_TYPE_EXPECTED
            )
        )
        with pytest.warns(DeprecationWarning) as recorded:
            simplefilter("always")
            sent = server.send(b"xy".decode("ascii"))
            assert expected_message == str(recorded[-1].message)
        assert sent == 2
        assert client.recv(2) == b"xy"

    def test_short_memoryview(self):
        """
        A memoryview over a few bytes handed to `Connection.send` is
        transmitted in full, and the number of bytes sent is returned.
        """
        server, client = loopback()
        assert server.send(memoryview(b"xy")) == 2
        assert client.recv(2) == b"xy"

    def test_short_bytearray(self):
        """
        A short bytearray handed to `Connection.send` is transmitted in
        full, and the number of bytes sent is returned.
        """
        server, client = loopback()
        assert server.send(bytearray(b"xy")) == 2
        assert client.recv(2) == b"xy"

    @skip_if_py3
    def test_short_buffer(self):
        """
        A `buffer` over a few bytes handed to `Connection.send` is
        transmitted in full, and the number of bytes sent is returned.
        """
        server, client = loopback()
        assert server.send(buffer(b"xy")) == 2  # noqa: F821
        assert client.recv(2) == b"xy"

    @pytest.mark.skipif(
        sys.maxsize < 2 ** 31,
        reason="sys.maxsize < 2**31 - test requires 64 bit",
    )
    def test_buf_too_large(self):
        """
        `Connection.send` refuses a buffer reporting >= 2**31 bytes with
        `ValueError`, since SSL_write only takes an int for the length.
        """
        connection = Connection(Context(SSLv23_METHOD), None)
        oversized = VeryLarge()
        with pytest.raises(ValueError) as exc_info:
            connection.send(oversized)
        exc_info.match(r"Cannot send more than .+ bytes at once")
3047
3048
3049def _make_memoryview(size):
3050    """
3051    Create a new ``memoryview`` wrapped around a ``bytearray`` of the given
3052    size.
3053    """
3054    return memoryview(bytearray(size))
3055
3056
class TestConnectionRecvInto(object):
    """
    Tests for `Connection.recv_into`.
    """

    def _no_length_test(self, factory):
        """
        Assert that when the given buffer is passed to `Connection.recv_into`,
        whatever bytes are available to be received that fit into that buffer
        are written into that buffer.

        :param factory: A callable taking a size and returning a writable
            buffer of that many bytes.
        """
        output_buffer = factory(5)

        server, client = loopback()
        server.send(b"xy")

        assert client.recv_into(output_buffer) == 2
        assert output_buffer == bytearray(b"xy\x00\x00\x00")

    def test_bytearray_no_length(self):
        """
        `Connection.recv_into` can be passed a `bytearray` instance and data
        in the receive buffer is written to it.
        """
        self._no_length_test(bytearray)

    def _respects_length_test(self, factory):
        """
        Assert that when the given buffer is passed to `Connection.recv_into`
        along with a value for `nbytes` that is less than the size of that
        buffer, only `nbytes` bytes are written into the buffer.

        :param factory: A callable taking a size and returning a writable
            buffer of that many bytes.
        """
        output_buffer = factory(10)

        server, client = loopback()
        server.send(b"abcdefghij")

        assert client.recv_into(output_buffer, 5) == 5
        assert output_buffer == bytearray(b"abcde\x00\x00\x00\x00\x00")

    def test_bytearray_respects_length(self):
        """
        When called with a `bytearray` instance, `Connection.recv_into`
        respects the `nbytes` parameter and doesn't copy in more than that
        number of bytes.
        """
        self._respects_length_test(bytearray)

    def _doesnt_overfill_test(self, factory, nbytes=None):
        """
        Assert that if there are more bytes available to be read from the
        receive buffer than would fit into the buffer passed to
        `Connection.recv_into`, only as many as fit are written into it.

        :param factory: A callable taking a size and returning a writable
            buffer of that many bytes.
        :param nbytes: If not `None`, passed to `Connection.recv_into` as
            its `nbytes` argument; callers use a value larger than the
            buffer to check that the buffer size still wins.
        """
        output_buffer = factory(5)

        server, client = loopback()
        server.send(b"abcdefghij")

        if nbytes is None:
            received = client.recv_into(output_buffer)
        else:
            received = client.recv_into(output_buffer, nbytes)

        assert received == 5
        assert output_buffer == bytearray(b"abcde")
        # The bytes that did not fit must still be readable afterwards.
        rest = client.recv(5)
        assert b"fghij" == rest

    def test_bytearray_doesnt_overfill(self):
        """
        When called with a `bytearray` instance, `Connection.recv_into`
        respects the size of the array and doesn't write more bytes into it
        than will fit.
        """
        self._doesnt_overfill_test(bytearray)

    def test_bytearray_really_doesnt_overfill(self):
        """
        When called with a `bytearray` instance and an `nbytes` value that is
        too large, `Connection.recv_into` respects the size of the array and
        not the `nbytes` value and doesn't write more bytes into the buffer
        than will fit.
        """
        # 50 exceeds the 5-byte buffer the helper allocates.
        self._doesnt_overfill_test(bytearray, nbytes=50)

    def test_peek(self):
        """
        With `MSG_PEEK`, `Connection.recv_into` copies the pending data into
        the buffer without consuming it, so a second peek sees it again.
        """
        server, client = loopback()
        server.send(b"xy")

        for _ in range(2):
            output_buffer = bytearray(5)
            assert client.recv_into(output_buffer, flags=MSG_PEEK) == 2
            assert output_buffer == bytearray(b"xy\x00\x00\x00")

    def test_memoryview_no_length(self):
        """
        `Connection.recv_into` can be passed a `memoryview` instance and data
        in the receive buffer is written to it.
        """
        self._no_length_test(_make_memoryview)

    def test_memoryview_respects_length(self):
        """
        When called with a `memoryview` instance, `Connection.recv_into`
        respects the ``nbytes`` parameter and doesn't copy more than that
        number of bytes in.
        """
        self._respects_length_test(_make_memoryview)

    def test_memoryview_doesnt_overfill(self):
        """
        When called with a `memoryview` instance, `Connection.recv_into`
        respects the size of the array and doesn't write more bytes into it
        than will fit.
        """
        self._doesnt_overfill_test(_make_memoryview)

    def test_memoryview_really_doesnt_overfill(self):
        """
        When called with a `memoryview` instance and an `nbytes` value that is
        too large, `Connection.recv_into` respects the size of the array and
        not the `nbytes` value and doesn't write more bytes into the buffer
        than will fit.
        """
        # 50 exceeds the 5-byte buffer the helper allocates.
        self._doesnt_overfill_test(_make_memoryview, nbytes=50)
3178
3179
class TestConnectionSendall(object):
    """
    Tests for `Connection.sendall`.
    """

    def test_wrong_args(self):
        """
        When called with arguments other than a string argument for its first
        parameter, `Connection.sendall` raises `TypeError`.
        """
        connection = Connection(Context(SSLv23_METHOD), None)
        with pytest.raises(TypeError):
            connection.sendall(object())
        with pytest.raises(TypeError):
            connection.sendall([1, 2, 3])

    def test_short(self):
        """
        `Connection.sendall` transmits all of the bytes in the string
        passed to it.
        """
        server, client = loopback()
        server.sendall(b"x")
        assert client.recv(1) == b"x"

    def test_text(self):
        """
        `Connection.sendall` transmits all the content in the string passed
        to it, raising a DeprecationWarning in case of this being a text.
        """
        server, client = loopback()
        with pytest.warns(DeprecationWarning) as w:
            simplefilter("always")
            server.sendall(b"x".decode("ascii"))
            assert "{0} for buf is no longer accepted, use bytes".format(
                WARNING_TYPE_EXPECTED
            ) == str(w[-1].message)
        assert client.recv(1) == b"x"

    def test_short_memoryview(self):
        """
        When passed a memoryview onto a small number of bytes,
        `Connection.sendall` transmits all of them.
        """
        server, client = loopback()
        server.sendall(memoryview(b"x"))
        assert client.recv(1) == b"x"

    @skip_if_py3
    def test_short_buffers(self):
        """
        When passed a buffer containing a small number of bytes,
        `Connection.sendall` transmits all of them.
        """
        server, client = loopback()
        count = server.sendall(buffer(b"xy"))  # noqa: F821
        assert count == 2
        assert client.recv(2) == b"xy"

    def test_long(self):
        """
        `Connection.sendall` transmits all the bytes in the string passed to it
        even if this requires multiple calls of an underlying write function.
        """
        server, client = loopback()
        # Should be enough, underlying SSL_write should only do 16k at a time.
        # On Windows, after 32k of bytes the write will block (forever
        # - because no one is yet reading).
        message = b"x" * (1024 * 32 - 1) + b"y"
        server.sendall(message)
        # Drain the client in 1k reads until the whole message has arrived.
        accum = []
        received = 0
        while received < len(message):
            data = client.recv(1024)
            accum.append(data)
            received += len(data)
        assert message == b"".join(accum)

    def test_closed(self):
        """
        If the underlying socket is closed, `Connection.sendall` propagates the
        write error from the low level write call.
        """
        server, client = loopback()
        # Use the named constant rather than the bare value 2 to shut down
        # both directions of the underlying socket.
        server.sock_shutdown(SHUT_RDWR)
        with pytest.raises(SysCallError) as err:
            server.sendall(b"hello, world")
        # The errno reported for writing to a shut-down socket differs
        # between Windows and other platforms.
        if platform == "win32":
            assert err.value.args[0] == ESHUTDOWN
        else:
            assert err.value.args[0] == EPIPE
3271
3272
class TestConnectionRenegotiate(object):
    """
    Tests for SSL renegotiation APIs.
    """

    def test_total_renegotiations(self):
        """
        `Connection.total_renegotiations` returns `0` before any renegotiations
        have happened.
        """
        connection = Connection(Context(SSLv23_METHOD), None)
        assert connection.total_renegotiations() == 0

    def test_renegotiate(self):
        """
        Go through a complete renegotiation cycle.
        """
        # Both ends are pinned to TLSv1_2_METHOD.
        # NOTE(review): presumably because TLS 1.3 has no renegotiation -
        # confirm.
        server, client = loopback(
            lambda s: loopback_server_factory(s, TLSv1_2_METHOD),
            lambda s: loopback_client_factory(s, TLSv1_2_METHOD),
        )

        # Exchange some application data before renegotiating.
        server.send(b"hello world")

        assert b"hello world" == client.recv(len(b"hello world"))

        # Baseline: nothing renegotiated and nothing pending yet.
        assert 0 == server.total_renegotiations()
        assert False is server.renegotiate_pending()

        # The server requests a renegotiation, which is then pending.
        assert True is server.renegotiate()

        assert True is server.renegotiate_pending()

        server.setblocking(False)
        client.setblocking(False)

        client.do_handshake()
        server.do_handshake()

        assert 1 == server.total_renegotiations()
        # NOTE(review): this loop spins only while *no* renegotiation is
        # pending, and its body does nothing to change connection state, so
        # if `renegotiate_pending()` returned `False` here the test would
        # hang. Confirm whether a plain assertion or a bounded wait was
        # intended.
        while False is server.renegotiate_pending():
            pass
3315
3316
class TestError(object):
    """
    Unit tests for `OpenSSL.SSL.Error`.
    """

    def test_type(self):
        """
        `Error` is a subclass of `Exception` and is named ``"Error"``.
        """
        assert issubclass(Error, Exception)
        assert "Error" == Error.__name__
3328
3329
class TestConstants(object):
    """
    Checks on the numeric values of the constants exported by `OpenSSL.SSL`.

    The constants are flag values defined by OpenSSL and are only meant to be
    handed back to OpenSSL APIs, so their value is the one thing that can be
    asserted about them.
    """

    @pytest.mark.skipif(
        OP_NO_QUERY_MTU is None,
        reason=(
            "OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old"
        ),
    )
    def test_op_no_query_mtu(self):
        """
        `OpenSSL.SSL.OP_NO_QUERY_MTU` is 0x1000, matching
        `SSL_OP_NO_QUERY_MTU` from `openssl/ssl.h`.
        """
        assert 0x1000 == OP_NO_QUERY_MTU

    @pytest.mark.skipif(
        OP_COOKIE_EXCHANGE is None,
        reason=(
            "OP_COOKIE_EXCHANGE unavailable - OpenSSL version may be too old"
        ),
    )
    def test_op_cookie_exchange(self):
        """
        `OpenSSL.SSL.OP_COOKIE_EXCHANGE` is 0x2000, matching
        `SSL_OP_COOKIE_EXCHANGE` from `openssl/ssl.h`.
        """
        assert 0x2000 == OP_COOKIE_EXCHANGE

    @pytest.mark.skipif(
        OP_NO_TICKET is None,
        reason=(
            "OP_NO_TICKET unavailable - OpenSSL version may be too old"
        ),
    )
    def test_op_no_ticket(self):
        """
        `OpenSSL.SSL.OP_NO_TICKET` is 0x4000, matching `SSL_OP_NO_TICKET`
        from `openssl/ssl.h`.
        """
        assert 0x4000 == OP_NO_TICKET

    @pytest.mark.skipif(
        OP_NO_COMPRESSION is None,
        reason="OP_NO_COMPRESSION unavailable - "
        "OpenSSL version may be too old",
    )
    def test_op_no_compression(self):
        """
        `OpenSSL.SSL.OP_NO_COMPRESSION` is 0x20000, matching
        `SSL_OP_NO_COMPRESSION` from `openssl/ssl.h`.
        """
        assert 0x20000 == OP_NO_COMPRESSION

    def test_sess_cache_off(self):
        """
        `OpenSSL.SSL.SESS_CACHE_OFF` is 0x0, matching `SSL_SESS_CACHE_OFF`
        from `openssl/ssl.h`.
        """
        assert SESS_CACHE_OFF == 0x0

    def test_sess_cache_client(self):
        """
        `OpenSSL.SSL.SESS_CACHE_CLIENT` is 0x1, matching
        `SSL_SESS_CACHE_CLIENT` from `openssl/ssl.h`.
        """
        assert SESS_CACHE_CLIENT == 0x1

    def test_sess_cache_server(self):
        """
        `OpenSSL.SSL.SESS_CACHE_SERVER` is 0x2, matching
        `SSL_SESS_CACHE_SERVER` from `openssl/ssl.h`.
        """
        assert SESS_CACHE_SERVER == 0x2

    def test_sess_cache_both(self):
        """
        `OpenSSL.SSL.SESS_CACHE_BOTH` is 0x3, matching `SSL_SESS_CACHE_BOTH`
        from `openssl/ssl.h`.
        """
        assert SESS_CACHE_BOTH == 0x3

    def test_sess_cache_no_auto_clear(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_AUTO_CLEAR` is 0x80, matching
        `SSL_SESS_CACHE_NO_AUTO_CLEAR` from `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_AUTO_CLEAR == 0x80

    def test_sess_cache_no_internal_lookup(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_LOOKUP` is 0x100, matching
        `SSL_SESS_CACHE_NO_INTERNAL_LOOKUP` from `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_INTERNAL_LOOKUP == 0x100

    def test_sess_cache_no_internal_store(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_STORE` is 0x200, matching
        `SSL_SESS_CACHE_NO_INTERNAL_STORE` from `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_INTERNAL_STORE == 0x200

    def test_sess_cache_no_internal(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL` is 0x300, matching
        `SSL_SESS_CACHE_NO_INTERNAL` from `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_INTERNAL == 0x300
3445
3446
class TestMemoryBIO(object):
    """
    Tests for `OpenSSL.SSL.Connection` using a memory BIO.
    """

    def _make_context(self, key_pem, cert_pem):
        """
        Build the `Context` shared by the `_server` and `_client` helpers.

        The context uses `SSLv23_METHOD` with SSLv2 and SSLv3 disabled,
        requires a peer certificate (checked through `verify_cb`), loads the
        given key/certificate pair and trusts the root certificate.

        :param key_pem: PEM-encoded private key to use for this side.
        :param cert_pem: PEM-encoded certificate matching `key_pem`.
        :return: The configured `Context`.
        """
        ctx = Context(SSLv23_METHOD)
        ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE)
        ctx.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE,
            verify_cb,
        )
        store = ctx.get_cert_store()
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, key_pem))
        ctx.use_certificate(load_certificate(FILETYPE_PEM, cert_pem))
        # Fail early if the key and certificate do not belong together.
        ctx.check_privatekey()
        store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
        return ctx

    def _server(self, sock):
        """
        Create a new server-side SSL `Connection` object wrapped around `sock`.

        :param sock: The socket to wrap, or `None` to use a memory BIO.
        :return: A `Connection` already put into accept state.
        """
        server_ctx = self._make_context(server_key_pem, server_cert_pem)
        # Here the Connection is actually created.  If None is passed as the
        # 2nd parameter, it indicates a memory BIO should be created.
        server_conn = Connection(server_ctx, sock)
        server_conn.set_accept_state()
        return server_conn

    def _client(self, sock):
        """
        Create a new client-side SSL `Connection` object wrapped around `sock`.

        :param sock: The socket to wrap, or `None` to use a memory BIO.
        :return: A `Connection` already put into connect state.
        """
        client_ctx = self._make_context(client_key_pem, client_cert_pem)
        client_conn = Connection(client_ctx, sock)
        client_conn.set_connect_state()
        return client_conn

    def test_memory_connect(self):
        """
        Two `Connection`s which use memory BIOs can be manually connected by
        reading from the output of each and writing those bytes to the input of
        the other and in this way establish a connection and exchange
        application-level bytes with each other.
        """
        server_conn = self._server(None)
        client_conn = self._client(None)

        # There should be no key or nonces yet.
        assert server_conn.master_key() is None
        assert server_conn.client_random() is None
        assert server_conn.server_random() is None

        # First, the handshake needs to happen.  We'll deliver bytes back and
        # forth between the client and server until neither of them feels like
        # speaking any more.
        assert interact_in_memory(client_conn, server_conn) is None

        # Now that the handshake is done, there should be a key and nonces.
        assert server_conn.master_key() is not None
        assert server_conn.client_random() is not None
        assert server_conn.server_random() is not None
        assert server_conn.client_random() == client_conn.client_random()
        assert server_conn.server_random() == client_conn.server_random()
        assert server_conn.client_random() != server_conn.server_random()
        assert client_conn.client_random() != client_conn.server_random()

        # Export key material for other uses.
        cekm = client_conn.export_keying_material(b"LABEL", 32)
        sekm = server_conn.export_keying_material(b"LABEL", 32)
        assert cekm is not None
        assert sekm is not None
        assert cekm == sekm
        assert len(sekm) == 32

        # Export key material for other uses with additional context.
        cekmc = client_conn.export_keying_material(b"LABEL", 32, b"CONTEXT")
        sekmc = server_conn.export_keying_material(b"LABEL", 32, b"CONTEXT")
        assert cekmc is not None
        assert sekmc is not None
        assert cekmc == sekmc
        assert cekmc != cekm
        assert sekmc != sekm
        # Export with alternate label
        cekmt = client_conn.export_keying_material(b"test", 32, b"CONTEXT")
        sekmt = server_conn.export_keying_material(b"test", 32, b"CONTEXT")
        assert cekmc != cekmt
        assert sekmc != sekmt

        # Here are the bytes we'll try to send.
        important_message = b"One if by land, two if by sea."

        server_conn.write(important_message)
        assert interact_in_memory(client_conn, server_conn) == (
            client_conn,
            important_message,
        )

        client_conn.write(important_message[::-1])
        assert interact_in_memory(client_conn, server_conn) == (
            server_conn,
            important_message[::-1],
        )

    def test_socket_connect(self):
        """
        Just like `test_memory_connect` but with an actual socket.

        This is primarily to rule out the memory BIO code as the source of any
        problems encountered while passing data over a `Connection` (if
        this test fails, there must be a problem outside the memory BIO code,
        as no memory BIO is involved here).  Even though this isn't a memory
        BIO test, it's convenient to have it here.
        """
        server_conn, client_conn = loopback()

        important_message = b"Help me Obi Wan Kenobi, you're my only hope."
        client_conn.send(important_message)
        msg = server_conn.recv(1024)
        assert msg == important_message

        # Again in the other direction, just for fun.
        important_message = important_message[::-1]
        server_conn.send(important_message)
        msg = client_conn.recv(1024)
        assert msg == important_message

    def test_socket_overrides_memory(self):
        """
        `Connection.bio_read`, `Connection.bio_write` and
        `Connection.bio_shutdown` raise `TypeError` on a `Connection` that
        wraps a socket rather than a memory BIO.
        """
        context = Context(SSLv23_METHOD)
        client = socket_any_family()
        clientSSL = Connection(context, client)
        with pytest.raises(TypeError):
            clientSSL.bio_read(100)
        with pytest.raises(TypeError):
            clientSSL.bio_write(b"foo")
        with pytest.raises(TypeError):
            clientSSL.bio_shutdown()

    def test_outgoing_overflow(self):
        """
        If more bytes than can be written to the memory BIO are passed to
        `Connection.send` at once, the number of bytes which were written is
        returned and that many bytes from the beginning of the input can be
        read from the other end of the connection.
        """
        server = self._server(None)
        client = self._client(None)

        # Complete the handshake before sending application data.
        interact_in_memory(client, server)

        size = 2 ** 15
        sent = client.send(b"x" * size)
        # Sanity check.  We're trying to test what happens when the entire
        # input can't be sent.  If the entire input was sent, this test is
        # meaningless.
        assert sent < size

        receiver, received = interact_in_memory(client, server)
        assert receiver is server

        # We can rely on all of these bytes being received at once because
        # the in-memory helper is expected to pass 2 ** 16 to recv - more
        # than 2 ** 15.  NOTE(review): helper is defined outside this class.
        assert len(received) == sent

    def test_shutdown(self):
        """
        `Connection.bio_shutdown` signals the end of the data stream
        from which the `Connection` reads.
        """
        server = self._server(None)
        server.bio_shutdown()
        with pytest.raises(Error) as err:
            server.recv(1024)
        # We don't want WantReadError or ZeroReturnError or anything - it's a
        # handshake failure.  The exact type is compared on purpose so that
        # other subclasses of Error are rejected.
        assert type(err.value) in [Error, SysCallError]

    def test_unexpected_EOF(self):
        """
        If the connection is lost before an orderly SSL shutdown occurs,
        `OpenSSL.SSL.SysCallError` is raised with a message of
        "Unexpected EOF".
        """
        server_conn, client_conn = loopback()
        client_conn.sock_shutdown(SHUT_RDWR)
        with pytest.raises(SysCallError) as err:
            server_conn.recv(1024)
        assert err.value.args == (-1, "Unexpected EOF")

    def _check_client_ca_list(self, func):
        """
        Verify the return value of the `get_client_ca_list` method for
        server and client connections.

        :param func: A function which will be called with the server context
            before the client and server are connected to each other.  This
            function should specify a list of CAs for the server to send to the
            client and return that same list.  The list will be used to verify
            that `get_client_ca_list` returns the proper value at
            various times.
        """
        server = self._server(None)
        client = self._client(None)
        # Nothing has been configured yet on either side.
        assert client.get_client_ca_list() == []
        assert server.get_client_ca_list() == []
        ctx = server.get_context()
        expected = func(ctx)
        # The server sees its configuration immediately; the client only
        # learns the list once the handshake has happened.
        assert client.get_client_ca_list() == []
        assert server.get_client_ca_list() == expected
        interact_in_memory(client, server)
        assert client.get_client_ca_list() == expected
        assert server.get_client_ca_list() == expected

    def test_set_client_ca_list_errors(self):
        """
        `Context.set_client_ca_list` raises a `TypeError` if called with a
        non-list or a list that contains objects other than X509Names.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            ctx.set_client_ca_list("spam")
        with pytest.raises(TypeError):
            ctx.set_client_ca_list(["spam"])

    def test_set_empty_ca_list(self):
        """
        If passed an empty list, `Context.set_client_ca_list` configures the
        context to send no CA names to the client and, on both the server and
        client sides, `Connection.get_client_ca_list` returns an empty list
        after the connection is set up.
        """

        def no_ca(ctx):
            ctx.set_client_ca_list([])
            return []

        self._check_client_ca_list(no_ca)

    def test_set_one_ca_list(self):
        """
        If passed a list containing a single X509Name,
        `Context.set_client_ca_list` configures the context to send
        that CA name to the client and, on both the server and client sides,
        `Connection.get_client_ca_list` returns a list containing that
        X509Name after the connection is set up.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        cadesc = cacert.get_subject()

        def single_ca(ctx):
            ctx.set_client_ca_list([cadesc])
            return [cadesc]

        self._check_client_ca_list(single_ca)

    def test_set_multiple_ca_list(self):
        """
        If passed a list containing multiple X509Name objects,
        `Context.set_client_ca_list` configures the context to send
        those CA names to the client and, on both the server and client sides,
        `Connection.get_client_ca_list` returns a list containing those
        X509Names after the connection is set up.
        """
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        # Use the client certificate so the list holds two different names
        # rather than the server's subject twice.
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        sedesc = secert.get_subject()
        cldesc = clcert.get_subject()

        def multiple_ca(ctx):
            names = [sedesc, cldesc]
            ctx.set_client_ca_list(names)
            return names

        self._check_client_ca_list(multiple_ca)

    def test_reset_ca_list(self):
        """
        If called multiple times, only the X509Names passed to the final call
        of `Context.set_client_ca_list` are used to configure the CA
        names sent to the client.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()
        cldesc = clcert.get_subject()

        def changed_ca(ctx):
            ctx.set_client_ca_list([sedesc, cldesc])
            ctx.set_client_ca_list([cadesc])
            return [cadesc]

        self._check_client_ca_list(changed_ca)

    def test_mutated_ca_list(self):
        """
        If the list passed to `Context.set_client_ca_list` is mutated
        afterwards, this does not affect the list of CA names sent to the
        client.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()

        def mutated_ca(ctx):
            names = [cadesc]
            # Pass the very list that is mutated below; handing over a
            # separate list would leave the mutation untested.
            ctx.set_client_ca_list(names)
            names.append(sedesc)
            return [cadesc]

        self._check_client_ca_list(mutated_ca)

    def test_add_client_ca_wrong_args(self):
        """
        `Context.add_client_ca` raises `TypeError` if called with
        a non-X509 object.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            ctx.add_client_ca("spam")

    def test_one_add_client_ca(self):
        """
        A certificate's subject can be added as a CA to be sent to the client
        with `Context.add_client_ca`.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        cadesc = cacert.get_subject()

        def single_ca(ctx):
            ctx.add_client_ca(cacert)
            return [cadesc]

        self._check_client_ca_list(single_ca)

    def test_multiple_add_client_ca(self):
        """
        Multiple CA names can be sent to the client by calling
        `Context.add_client_ca` with multiple X509 objects.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()

        def multiple_ca(ctx):
            ctx.add_client_ca(cacert)
            ctx.add_client_ca(secert)
            return [cadesc, sedesc]

        self._check_client_ca_list(multiple_ca)

    def test_set_and_add_client_ca(self):
        """
        A call to `Context.set_client_ca_list` followed by a call to
        `Context.add_client_ca` results in using the CA names from the
        first call and the CA name from the second call.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        # A third, distinct subject so the appended name is distinguishable
        # from the two that were set.
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()
        cldesc = clcert.get_subject()

        def mixed_set_add_ca(ctx):
            ctx.set_client_ca_list([cadesc, sedesc])
            ctx.add_client_ca(clcert)
            return [cadesc, sedesc, cldesc]

        self._check_client_ca_list(mixed_set_add_ca)

    def test_set_after_add_client_ca(self):
        """
        A call to `Context.set_client_ca_list` after a call to
        `Context.add_client_ca` replaces the CA name specified by the
        former call with the names specified by the latter call.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        # The name added first must differ from every expected name, or its
        # replacement could not be observed.
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()

        def set_replaces_add_ca(ctx):
            ctx.add_client_ca(clcert)
            ctx.set_client_ca_list([cadesc])
            ctx.add_client_ca(secert)
            return [cadesc, sedesc]

        self._check_client_ca_list(set_replaces_add_ca)
3869
3870
class TestInfoConstants(object):
    """
    Tests for the constants exposed for use with info callbacks.
    """

    def test_integers(self):
        """
        Every info constant is an integer.

        This is only a weak check.  A stronger test would trigger specific
        info events and verify that the value handed to the info callback
        matches the corresponding constant exposed by OpenSSL.SSL.
        """
        always_present = (
            SSL_ST_CONNECT,
            SSL_ST_ACCEPT,
            SSL_ST_MASK,
            SSL_CB_LOOP,
            SSL_CB_EXIT,
            SSL_CB_READ,
            SSL_CB_WRITE,
            SSL_CB_ALERT,
            SSL_CB_READ_ALERT,
            SSL_CB_WRITE_ALERT,
            SSL_CB_ACCEPT_LOOP,
            SSL_CB_ACCEPT_EXIT,
            SSL_CB_CONNECT_LOOP,
            SSL_CB_CONNECT_EXIT,
            SSL_CB_HANDSHAKE_START,
            SSL_CB_HANDSHAKE_DONE,
        )
        for value in always_present:
            assert isinstance(value, int)

        # OpenSSL 1.1.0 does not have these constants, so None is accepted
        # in their place.
        maybe_missing = (
            SSL_ST_INIT,
            SSL_ST_BEFORE,
            SSL_ST_OK,
            SSL_ST_RENEGOTIATE,
        )
        for value in maybe_missing:
            assert isinstance(value, (int, type(None)))
3912
3913
class TestRequires(object):
    """
    Tests for the decorator factory that makes a function raise
    NotImplementedError when the required OpenSSL feature is missing.
    """

    def test_available(self):
        """
        A function guarded by an available feature runs normally and its
        return value is passed through.
        """
        calls = []

        def inner():
            calls.append(True)
            return True

        # Apply the guard by hand; this is what the decorator syntax does.
        guarded = _make_requires(True, "Error text")(inner)

        assert guarded() is True
        assert calls == [True]

    def test_unavailable(self):
        """
        A function guarded by an unavailable feature is never executed;
        calling it raises NotImplementedError carrying the given text.
        """

        def inner():  # pragma: nocover
            pytest.fail("Should not be called")

        guarded = _make_requires(False, "Error text")(inner)

        with pytest.raises(NotImplementedError) as excinfo:
            guarded()

        assert "Error text" in str(excinfo.value)
3951
3952
class TestOCSP(object):
    """
    Tests for PyOpenSSL's OCSP stapling support.
    """

    # Stand-in payload returned by server callbacks in these tests.
    sample_ocsp_data = b"this is totally ocsp data"

    def _client_connection(self, callback, data, request_ocsp=True):
        """
        Create a client-side connection with an OCSP callback installed.

        :param callback: The callback to register for OCSP.
        :param data: The opaque data object that will be handed to the
            OCSP callback.
        :param request_ocsp: Whether the client will actually ask for OCSP
            stapling. Useful for testing only.
        :return: A `Connection` in connect state.
        """
        context = Context(SSLv23_METHOD)
        context.set_ocsp_client_callback(callback, data)
        conn = Connection(context)

        if request_ocsp:
            conn.request_ocsp()

        conn.set_connect_state()
        return conn

    def _server_connection(self, callback, data):
        """
        Create a server-side connection with an OCSP callback installed.

        :param callback: The callback to register for OCSP.
        :param data: The opaque data object that will be handed to the
            OCSP callback.
        :return: A `Connection` in accept state.
        """
        key = load_privatekey(FILETYPE_PEM, server_key_pem)
        cert = load_certificate(FILETYPE_PEM, server_cert_pem)

        context = Context(SSLv23_METHOD)
        context.use_privatekey(key)
        context.use_certificate(cert)
        context.set_ocsp_server_callback(callback, data)

        conn = Connection(context)
        conn.set_accept_state()
        return conn

    def test_callbacks_arent_called_by_default(self):
        """
        When both sides have OCSP callbacks registered but the client never
        sends an OCSP request, neither callback runs.
        """

        def ocsp_callback(*args, **kwargs):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(
            callback=ocsp_callback, data=None, request_ocsp=False
        )
        server = self._server_connection(callback=ocsp_callback, data=None)
        handshake_in_memory(client, server)

    def test_client_negotiates_without_server(self):
        """
        When the client asks for OCSP but the server does not support it,
        the handshake still succeeds and the client callback is invoked
        with an empty byte string.
        """
        received = []

        def ocsp_callback(conn, ocsp_data, ignored):
            received.append(ocsp_data)
            return True

        client = self._client_connection(callback=ocsp_callback, data=None)
        server = loopback_server_factory(socket=None)
        handshake_in_memory(client, server)

        assert received == [b""]

    def test_client_receives_servers_data(self):
        """
        Whatever the server callback returns is delivered to the client
        callback.
        """
        received = []

        def server_callback(*args, **kwargs):
            return self.sample_ocsp_data

        def client_callback(conn, ocsp_data, ignored):
            received.append(ocsp_data)
            return True

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert received == [self.sample_ocsp_data]

    def test_callbacks_are_invoked_with_connections(self):
        """
        Each callback receives its own connection as the first argument.
        """
        seen_by_client = []
        seen_by_server = []

        def client_callback(conn, *args, **kwargs):
            seen_by_client.append(conn)
            return True

        def server_callback(conn, *args, **kwargs):
            seen_by_server.append(conn)
            return self.sample_ocsp_data

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert 1 == len(seen_by_client)
        assert 1 == len(seen_by_server)
        assert seen_by_client[0] is client
        assert seen_by_server[0] is server

    def test_opaque_data_is_passed_through(self):
        """
        The user-provided opaque data object is handed to both callbacks as
        their final argument.
        """
        invocations = []

        def server_callback(*args):
            invocations.append(args)
            return self.sample_ocsp_data

        def client_callback(*args):
            invocations.append(args)
            return True

        marker = object()

        client = self._client_connection(callback=client_callback, data=marker)
        server = self._server_connection(callback=server_callback, data=marker)
        handshake_in_memory(client, server)

        assert 2 == len(invocations)
        first, second = invocations[0], invocations[1]
        assert first[-1] is marker
        assert second[-1] is marker

    def test_server_returns_empty_string(self):
        """
        An empty bytestring returned by the server callback reaches the
        client callback as an empty bytestring.
        """
        received = []

        def server_callback(*args):
            return b""

        def client_callback(conn, ocsp_data, ignored):
            received.append(ocsp_data)
            return True

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert received == [b""]

    def test_client_returns_false_terminates_handshake(self):
        """
        The handshake fails when the client callback returns False.
        """

        def server_callback(*args):
            return self.sample_ocsp_data

        def client_callback(*args):
            return False

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(Error):
            handshake_in_memory(client, server)

    def test_exceptions_in_client_bubble_up(self):
        """
        An exception raised inside the client callback propagates to the
        code driving the handshake.
        """

        class SentinelException(Exception):
            pass

        def server_callback(*args):
            return self.sample_ocsp_data

        def client_callback(*args):
            raise SentinelException()

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(SentinelException):
            handshake_in_memory(client, server)

    def test_exceptions_in_server_bubble_up(self):
        """
        An exception raised inside the server callback propagates to the
        code driving the handshake.
        """

        class SentinelException(Exception):
            pass

        def server_callback(*args):
            raise SentinelException()

        def client_callback(*args):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(SentinelException):
            handshake_in_memory(client, server)

    def test_server_must_return_bytes(self):
        """
        A TypeError is raised when the server callback returns something
        other than a bytestring.
        """

        def server_callback(*args):
            # Deliberately hand back text instead of bytes.
            return self.sample_ocsp_data.decode("ascii")

        def client_callback(*args):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(TypeError):
            handshake_in_memory(client, server)
4197