• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (C) Jean-Paul Calderone
2# See LICENSE for details.
3
4"""
5Unit tests for :mod:`OpenSSL.SSL`.
6"""
7
8import datetime
9import sys
10import uuid
11
12from gc import collect, get_referrers
13from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
14from sys import platform, getfilesystemencoding
15from socket import MSG_PEEK, SHUT_RDWR, error, socket
16from os import makedirs
17from os.path import join
18from weakref import ref
19from warnings import simplefilter
20
21import pytest
22
23from pretend import raiser
24
25from six import PY3, text_type
26
27from cryptography import x509
28from cryptography.hazmat.backends import default_backend
29from cryptography.hazmat.primitives import hashes
30from cryptography.hazmat.primitives import serialization
31from cryptography.hazmat.primitives.asymmetric import rsa
32from cryptography.x509.oid import NameOID
33
34
35from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
36from OpenSSL.crypto import PKey, X509, X509Extension, X509Store
37from OpenSSL.crypto import dump_privatekey, load_privatekey
38from OpenSSL.crypto import dump_certificate, load_certificate
39from OpenSSL.crypto import get_elliptic_curves
40
41from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS
42from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON
43from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
44from OpenSSL.SSL import (
45    SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD,
46    TLSv1_1_METHOD, TLSv1_2_METHOD)
47from OpenSSL.SSL import OP_SINGLE_DH_USE, OP_NO_SSLv2, OP_NO_SSLv3
48from OpenSSL.SSL import (
49    VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE, VERIFY_NONE)
50
51from OpenSSL import SSL
52from OpenSSL.SSL import (
53    SESS_CACHE_OFF, SESS_CACHE_CLIENT, SESS_CACHE_SERVER, SESS_CACHE_BOTH,
54    SESS_CACHE_NO_AUTO_CLEAR, SESS_CACHE_NO_INTERNAL_LOOKUP,
55    SESS_CACHE_NO_INTERNAL_STORE, SESS_CACHE_NO_INTERNAL)
56
57from OpenSSL.SSL import (
58    Error, SysCallError, WantReadError, WantWriteError, ZeroReturnError)
59from OpenSSL.SSL import (
60    Context, ContextType, Session, Connection, ConnectionType, SSLeay_version)
61from OpenSSL.SSL import _make_requires
62
63from OpenSSL._util import ffi as _ffi, lib as _lib
64
65from OpenSSL.SSL import (
66    OP_NO_QUERY_MTU, OP_COOKIE_EXCHANGE, OP_NO_TICKET, OP_NO_COMPRESSION,
67    MODE_RELEASE_BUFFERS)
68
69from OpenSSL.SSL import (
70    SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK,
71    SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT,
72    SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP,
73    SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT,
74    SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE)
75
76try:
77    from OpenSSL.SSL import (
78        SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE
79    )
80except ImportError:
81    SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None
82
83from .util import WARNING_TYPE_EXPECTED, NON_ASCII, is_consistent_type
84from .test_crypto import (
85    cleartextCertificatePEM, cleartextPrivateKeyPEM,
86    client_cert_pem, client_key_pem, server_cert_pem, server_key_pem,
87    root_cert_pem)
88
89
# PEM-encoded Diffie-Hellman parameters, generated with:
#
#     openssl dhparam 1024 -out dh-1024.pem
#
# (note that 1024 is a small number of bits to use)
dhparam = """\
-----BEGIN DH PARAMETERS-----
MIGHAoGBALdUMvn+C9MM+y5BWZs11mSeH6HHoEq0UVbzVq7UojC1hbsZUuGukQ3a
Qh2/pwqb18BZFykrWB0zv/OkLa0kx4cuUgNrUVq1EFheBiX6YqryJ7t2sO09NQiO
V7H54LmltOT/hEh6QWsJqb6BQgH65bswvV/XkYGja8/T0GzvbaVzAgEC
-----END DH PARAMETERS-----
"""


# Marker that skips the decorated test whenever ``six.PY3`` is true.
skip_if_py3 = pytest.mark.skipif(PY3, reason="Python 2 only")
102
103
def join_bytes_or_unicode(prefix, suffix):
    """
    Join two path components, each of which is ``bytes`` or ``unicode``.

    When the two components have different types, ``suffix`` is converted
    to the type of ``prefix`` using the filesystem encoding, so the result
    always has the same type as ``prefix``.
    """
    if type(prefix) != type(suffix):
        # Mixed types: coerce the suffix so that ``join`` sees one type.
        fs_encoding = getfilesystemencoding()
        if isinstance(prefix, text_type):
            suffix = suffix.decode(fs_encoding)
        else:
            suffix = suffix.encode(fs_encoding)
    return join(prefix, suffix)
119
120
def verify_cb(conn, cert, errnum, depth, ok):
    """
    Verification callback which hands back the ``ok`` value it was given
    unchanged; every other argument is ignored.
    """
    return ok
123
124
def socket_pair():
    """
    Establish and return a pair of network sockets connected to each other.

    A temporary listening socket is bound to an ephemeral port, a client
    socket connects to it over ``127.0.0.1`` and the accepted socket becomes
    the server side.  The listening socket is closed before returning so it
    does not leak a file descriptor.

    :return: A ``(server, client)`` two-tuple of connected sockets, both in
        non-blocking mode.
    """
    listener = socket()
    try:
        listener.bind(('', 0))
        listener.listen(1)
        client = socket()
        # Start the connect without blocking; the accept below completes it.
        client.setblocking(False)
        client.connect_ex(("127.0.0.1", listener.getsockname()[1]))
        client.setblocking(True)
        server = listener.accept()[0]
    finally:
        # The listener is only needed to establish this one connection.
        listener.close()

    # Let's pass some unencrypted data to make sure our socket connection is
    # fine.  Just one byte, so we don't have to worry about buffers getting
    # filled up or fragmentation.
    server.send(b"x")
    assert client.recv(1024) == b"x"
    client.send(b"y")
    assert server.recv(1024) == b"y"

    # Most of our callers want non-blocking sockets, make it easy for them.
    server.setblocking(False)
    client.setblocking(False)

    return (server, client)
152
153
def handshake(client, server):
    """
    Call ``do_handshake`` on both connections until each call has succeeded.

    A connection whose ``do_handshake`` raises `WantReadError` is retried on
    the next pass; a connection whose call returns normally is dropped from
    the pending list.  Any other exception propagates to the caller.

    :param client: The first connection to drive.
    :param server: The second connection to drive.
    :return: `None`
    """
    pending = [client, server]
    while pending:
        # Iterate over a snapshot: removing an element from the list that is
        # being iterated makes the iterator skip the element after it.
        for conn in list(pending):
            try:
                conn.do_handshake()
            except WantReadError:
                pass
            else:
                pending.remove(conn)
164
165
def _create_certificate_chain():
    """
    Build a three-level certificate chain and return it as a list of
    ``(key, certificate)`` pairs, ordered from the root downwards:

        1. A new self-signed certificate authority certificate (cacert)
        2. A new intermediate certificate signed by cacert (icert)
        3. A new server certificate signed by icert (scert)
    """
    def issue(common_name, extension, issuer=None):
        # Create a fresh key and a certificate for it.  ``issuer`` is the
        # ``(key, certificate)`` pair that signs the new certificate; when
        # it is omitted the certificate is self-signed.
        key = PKey()
        key.generate_key(TYPE_RSA, 1024)
        cert = X509()
        cert.get_subject().commonName = common_name
        if issuer is None:
            signing_key, issuing_cert = key, cert
        else:
            signing_key, issuing_cert = issuer
        cert.set_issuer(issuing_cert.get_subject())
        cert.set_pubkey(key)
        cert.set_notBefore(b"20000101000000Z")
        cert.set_notAfter(b"20200101000000Z")
        cert.add_extensions([extension])
        cert.set_serial_number(0)
        cert.sign(signing_key, "sha1")
        return (key, cert)

    # The same CA extension object is attached to both CA-level certificates.
    ca_extension = X509Extension(b'basicConstraints', False, b'CA:true')
    leaf_extension = X509Extension(b'basicConstraints', True, b'CA:false')

    authority = issue("Authority Certificate", ca_extension)
    intermediate = issue("Intermediate Certificate", ca_extension, authority)
    server = issue("Server Certificate", leaf_extension, intermediate)

    return [authority, intermediate, server]
217
218
def loopback_client_factory(socket, version=SSLv23_METHOD):
    """
    Wrap ``socket`` in a `Connection` that uses a new `Context` created for
    ``version``, put the connection into connect state and return it.
    """
    ctx = Context(version)
    connection = Connection(ctx, socket)
    connection.set_connect_state()
    return connection
223
224
def loopback_server_factory(socket, version=SSLv23_METHOD):
    """
    Wrap ``socket`` in a `Connection` whose new `Context` (created for
    ``version``) is loaded with the server key and certificate from
    ``server_key_pem`` / ``server_cert_pem``, put the connection into accept
    state and return it.
    """
    server_ctx = Context(version)
    private_key = load_privatekey(FILETYPE_PEM, server_key_pem)
    server_ctx.use_privatekey(private_key)
    certificate = load_certificate(FILETYPE_PEM, server_cert_pem)
    server_ctx.use_certificate(certificate)

    connection = Connection(server_ctx, socket)
    connection.set_accept_state()
    return connection
232
233
def loopback(server_factory=None, client_factory=None):
    """
    Create a connected socket pair, wrap each end using the given factories,
    run the handshake between the two resulting connections and return them
    as ``(server, client)`` with blocking mode switched on.

    ``server_factory`` and ``client_factory`` are each called with one
    socket; when omitted, `loopback_server_factory` and
    `loopback_client_factory` are used.
    """
    make_server = (
        loopback_server_factory if server_factory is None else server_factory
    )
    make_client = (
        loopback_client_factory if client_factory is None else client_factory
    )

    server_socket, client_socket = socket_pair()
    server = make_server(server_socket)
    client = make_client(client_socket)

    handshake(client, server)

    for connection in (server, client):
        connection.setblocking(True)
    return server, client
253
254
def interact_in_memory(client_conn, server_conn):
    """
    Shuttle bytes between two `Connection` objects until one of them yields
    application data or neither has anything left to send.

    On every pass each connection is first asked for application bytes; if
    that succeeds, a two-tuple of that connection and the bytes is returned
    straight away.  Otherwise whatever is waiting in its send buffer is
    copied into the other connection's receive buffer.  When a complete pass
    copies nothing, `None` is returned.
    """
    directions = ((client_conn, server_conn), (server_conn, client_conn))

    progressed = True
    while progressed:
        # Stays false unless some side produces output during this pass.
        progressed = False

        for source, sink in directions:
            # Give the side a chance to generate some more bytes, or succeed.
            try:
                payload = source.recv(2 ** 16)
            except WantReadError:
                # No application data yet; it may have produced output.
                pass
            else:
                # Application data arrived: let the caller deal with it.
                return (source, payload)

            # Drain the send buffer into the peer's receive buffer.
            while True:
                try:
                    outgoing = source.bio_read(4096)
                except WantReadError:
                    # Nothing more waiting to be sent from this side.
                    break
                progressed = True
                sink.bio_write(outgoing)
298
299
def handshake_in_memory(client_conn, server_conn):
    """
    Run the TLS handshake between two `Connection` instances that talk to
    each other through memory BIOs.

    The client is put into connect state and the server into accept state,
    each is asked to start its handshake, and `interact_in_memory` then
    moves the resulting bytes between them.
    """
    client_conn.set_connect_state()
    server_conn.set_accept_state()

    for endpoint in (client_conn, server_conn):
        try:
            endpoint.do_handshake()
        except WantReadError:
            # Expected: the endpoint is waiting for bytes from its peer.
            pass

    interact_in_memory(client_conn, server_conn)
315
316
class TestVersion(object):
    """
    Tests for the version information exposed through
    `OpenSSL.SSL.SSLeay_version` and `OpenSSL.SSL.OPENSSL_VERSION_NUMBER`.
    """
    def test_OPENSSL_VERSION_NUMBER(self):
        """
        `OPENSSL_VERSION_NUMBER` is an integer with status in the low byte and
        the patch, fix, minor, and major versions in the nibbles above that.
        """
        assert isinstance(OPENSSL_VERSION_NUMBER, int)

    def test_SSLeay_version(self):
        """
        `SSLeay_version` returns a byte string for each version type
        indicator, and the five indicators give five distinct strings.
        """
        indicators = (
            SSLEAY_VERSION, SSLEAY_CFLAGS, SSLEAY_BUILT_ON,
            SSLEAY_PLATFORM, SSLEAY_DIR,
        )
        by_text = {}
        for indicator in indicators:
            text = SSLeay_version(indicator)
            by_text[text] = indicator
            assert isinstance(text, bytes)
        # Keyed by the returned string, so duplicates would collapse.
        assert len(by_text) == 5
341
342
@pytest.fixture
def ca_file(tmpdir):
    """
    Write a freshly generated, self-signed CA certificate to ``test.pem``
    inside ``tmpdir`` and return the file's path as ASCII-encoded bytes.
    """
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend()
    )

    def common_name():
        # Subject and issuer carry the same single attribute.
        return x509.Name([
            x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"),
        ])

    one_day = datetime.timedelta(days=1)
    builder = (
        x509.CertificateBuilder()
        .subject_name(common_name())
        .issuer_name(common_name())
        .not_valid_before(datetime.datetime.today() - one_day)
        .not_valid_after(datetime.datetime.today() + one_day)
        .serial_number(int(uuid.uuid4()))
        .public_key(private_key.public_key())
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None), critical=True,
        )
    )

    certificate = builder.sign(
        private_key=private_key, algorithm=hashes.SHA256(),
        backend=default_backend()
    )

    pem_path = tmpdir.join("test.pem")
    pem_bytes = certificate.public_bytes(
        encoding=serialization.Encoding.PEM,
    )
    pem_path.write_binary(pem_bytes)

    return str(pem_path).encode("ascii")
384
385
@pytest.fixture
def context():
    """
    Provide a new `Context` created with `TLSv1_METHOD`.
    """
    tls_context = Context(TLSv1_METHOD)
    return tls_context
392
393
394class TestContext(object):
395    """
396    Unit tests for `OpenSSL.SSL.Context`.
397    """
398    @pytest.mark.parametrize("cipher_string", [
399        b"hello world:AES128-SHA",
400        u"hello world:AES128-SHA",
401    ])
402    def test_set_cipher_list(self, context, cipher_string):
403        """
404        `Context.set_cipher_list` accepts both byte and unicode strings
405        for naming the ciphers which connections created with the context
406        object will be able to choose from.
407        """
408        context.set_cipher_list(cipher_string)
409        conn = Connection(context, None)
410
411        assert "AES128-SHA" in conn.get_cipher_list()
412
413    @pytest.mark.parametrize("cipher_list,error", [
414        (object(), TypeError),
415        ("imaginary-cipher", Error),
416    ])
417    def test_set_cipher_list_wrong_args(self, context, cipher_list, error):
418        """
419        `Context.set_cipher_list` raises `TypeError` when passed a non-string
420        argument and raises `OpenSSL.SSL.Error` when passed an incorrect cipher
421        list string.
422        """
423        with pytest.raises(error):
424            context.set_cipher_list(cipher_list)
425
426    def test_load_client_ca(self, context, ca_file):
427        """
428        `Context.load_client_ca` works as far as we can tell.
429        """
430        context.load_client_ca(ca_file)
431
432    def test_load_client_ca_invalid(self, context, tmpdir):
433        """
434        `Context.load_client_ca` raises an Error if the ca file is invalid.
435        """
436        ca_file = tmpdir.join("test.pem")
437        ca_file.write("")
438
439        with pytest.raises(Error) as e:
440            context.load_client_ca(str(ca_file).encode("ascii"))
441
442        assert "PEM routines" == e.value.args[0][0][0]
443
444    def test_load_client_ca_unicode(self, context, ca_file):
445        """
446        Passing the path as unicode raises a warning but works.
447        """
448        pytest.deprecated_call(
449            context.load_client_ca, ca_file.decode("ascii")
450        )
451
452    def test_set_session_id(self, context):
453        """
454        `Context.set_session_id` works as far as we can tell.
455        """
456        context.set_session_id(b"abc")
457
458    def test_set_session_id_fail(self, context):
459        """
460        `Context.set_session_id` errors are propagated.
461        """
462        with pytest.raises(Error) as e:
463            context.set_session_id(b"abc" * 1000)
464
465        assert [
466            ("SSL routines",
467             "SSL_CTX_set_session_id_context",
468             "ssl session id context too long")
469        ] == e.value.args[0]
470
471    def test_set_session_id_unicode(self, context):
472        """
473        `Context.set_session_id` raises a warning if a unicode string is
474        passed.
475        """
476        pytest.deprecated_call(context.set_session_id, u"abc")
477
478    def test_method(self):
479        """
480        `Context` can be instantiated with one of `SSLv2_METHOD`,
481        `SSLv3_METHOD`, `SSLv23_METHOD`, `TLSv1_METHOD`, `TLSv1_1_METHOD`,
482        or `TLSv1_2_METHOD`.
483        """
484        methods = [SSLv23_METHOD, TLSv1_METHOD]
485        for meth in methods:
486            Context(meth)
487
488        maybe = [SSLv2_METHOD, SSLv3_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD]
489        for meth in maybe:
490            try:
491                Context(meth)
492            except (Error, ValueError):
493                # Some versions of OpenSSL have SSLv2 / TLSv1.1 / TLSv1.2, some
494                # don't.  Difficult to say in advance.
495                pass
496
497        with pytest.raises(TypeError):
498            Context("")
499        with pytest.raises(ValueError):
500            Context(10)
501
502    @skip_if_py3
503    def test_method_long(self):
504        """
505        On Python 2 `Context` accepts values of type `long` as well as `int`.
506        """
507        Context(long(TLSv1_METHOD))
508
509    def test_type(self):
510        """
511        `Context` and `ContextType` refer to the same type object and can
512        be used to create instances of that type.
513        """
514        assert Context is ContextType
515        assert is_consistent_type(Context, 'Context', TLSv1_METHOD)
516
517    def test_use_privatekey(self):
518        """
519        `Context.use_privatekey` takes an `OpenSSL.crypto.PKey` instance.
520        """
521        key = PKey()
522        key.generate_key(TYPE_RSA, 512)
523        ctx = Context(TLSv1_METHOD)
524        ctx.use_privatekey(key)
525        with pytest.raises(TypeError):
526            ctx.use_privatekey("")
527
528    def test_use_privatekey_file_missing(self, tmpfile):
529        """
530        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` when passed
531        the name of a file which does not exist.
532        """
533        ctx = Context(TLSv1_METHOD)
534        with pytest.raises(Error):
535            ctx.use_privatekey_file(tmpfile)
536
537    def _use_privatekey_file_test(self, pemfile, filetype):
538        """
539        Verify that calling ``Context.use_privatekey_file`` with the given
540        arguments does not raise an exception.
541        """
542        key = PKey()
543        key.generate_key(TYPE_RSA, 512)
544
545        with open(pemfile, "wt") as pem:
546            pem.write(
547                dump_privatekey(FILETYPE_PEM, key).decode("ascii")
548            )
549
550        ctx = Context(TLSv1_METHOD)
551        ctx.use_privatekey_file(pemfile, filetype)
552
553    @pytest.mark.parametrize('filetype', [object(), "", None, 1.0])
554    def test_wrong_privatekey_file_wrong_args(self, tmpfile, filetype):
555        """
556        `Context.use_privatekey_file` raises `TypeError` when called with
557        a `filetype` which is not a valid file encoding.
558        """
559        ctx = Context(TLSv1_METHOD)
560        with pytest.raises(TypeError):
561            ctx.use_privatekey_file(tmpfile, filetype)
562
563    def test_use_privatekey_file_bytes(self, tmpfile):
564        """
565        A private key can be specified from a file by passing a ``bytes``
566        instance giving the file name to ``Context.use_privatekey_file``.
567        """
568        self._use_privatekey_file_test(
569            tmpfile + NON_ASCII.encode(getfilesystemencoding()),
570            FILETYPE_PEM,
571        )
572
573    def test_use_privatekey_file_unicode(self, tmpfile):
574        """
575        A private key can be specified from a file by passing a ``unicode``
576        instance giving the file name to ``Context.use_privatekey_file``.
577        """
578        self._use_privatekey_file_test(
579            tmpfile.decode(getfilesystemencoding()) + NON_ASCII,
580            FILETYPE_PEM,
581        )
582
583    @skip_if_py3
584    def test_use_privatekey_file_long(self, tmpfile):
585        """
586        On Python 2 `Context.use_privatekey_file` accepts a filetype of
587        type `long` as well as `int`.
588        """
589        self._use_privatekey_file_test(tmpfile, long(FILETYPE_PEM))
590
591    def test_use_certificate_wrong_args(self):
592        """
593        `Context.use_certificate_wrong_args` raises `TypeError` when not passed
594        exactly one `OpenSSL.crypto.X509` instance as an argument.
595        """
596        ctx = Context(TLSv1_METHOD)
597        with pytest.raises(TypeError):
598            ctx.use_certificate("hello, world")
599
600    def test_use_certificate_uninitialized(self):
601        """
602        `Context.use_certificate` raises `OpenSSL.SSL.Error` when passed a
603        `OpenSSL.crypto.X509` instance which has not been initialized
604        (ie, which does not actually have any certificate data).
605        """
606        ctx = Context(TLSv1_METHOD)
607        with pytest.raises(Error):
608            ctx.use_certificate(X509())
609
610    def test_use_certificate(self):
611        """
612        `Context.use_certificate` sets the certificate which will be
613        used to identify connections created using the context.
614        """
615        # TODO
616        # Hard to assert anything.  But we could set a privatekey then ask
617        # OpenSSL if the cert and key agree using check_privatekey.  Then as
618        # long as check_privatekey works right we're good...
619        ctx = Context(TLSv1_METHOD)
620        ctx.use_certificate(
621            load_certificate(FILETYPE_PEM, cleartextCertificatePEM)
622        )
623
624    def test_use_certificate_file_wrong_args(self):
625        """
626        `Context.use_certificate_file` raises `TypeError` if the first
627        argument is not a byte string or the second argument is not an integer.
628        """
629        ctx = Context(TLSv1_METHOD)
630        with pytest.raises(TypeError):
631            ctx.use_certificate_file(object(), FILETYPE_PEM)
632        with pytest.raises(TypeError):
633            ctx.use_certificate_file(b"somefile", object())
634        with pytest.raises(TypeError):
635            ctx.use_certificate_file(object(), FILETYPE_PEM)
636
637    def test_use_certificate_file_missing(self, tmpfile):
638        """
639        `Context.use_certificate_file` raises `OpenSSL.SSL.Error` if passed
640        the name of a file which does not exist.
641        """
642        ctx = Context(TLSv1_METHOD)
643        with pytest.raises(Error):
644            ctx.use_certificate_file(tmpfile)
645
646    def _use_certificate_file_test(self, certificate_file):
647        """
648        Verify that calling ``Context.use_certificate_file`` with the given
649        filename doesn't raise an exception.
650        """
651        # TODO
652        # Hard to assert anything.  But we could set a privatekey then ask
653        # OpenSSL if the cert and key agree using check_privatekey.  Then as
654        # long as check_privatekey works right we're good...
655        with open(certificate_file, "wb") as pem_file:
656            pem_file.write(cleartextCertificatePEM)
657
658        ctx = Context(TLSv1_METHOD)
659        ctx.use_certificate_file(certificate_file)
660
661    def test_use_certificate_file_bytes(self, tmpfile):
662        """
663        `Context.use_certificate_file` sets the certificate (given as a
664        `bytes` filename) which will be used to identify connections created
665        using the context.
666        """
667        filename = tmpfile + NON_ASCII.encode(getfilesystemencoding())
668        self._use_certificate_file_test(filename)
669
670    def test_use_certificate_file_unicode(self, tmpfile):
671        """
672        `Context.use_certificate_file` sets the certificate (given as a
673        `bytes` filename) which will be used to identify connections created
674        using the context.
675        """
676        filename = tmpfile.decode(getfilesystemencoding()) + NON_ASCII
677        self._use_certificate_file_test(filename)
678
679    @skip_if_py3
680    def test_use_certificate_file_long(self, tmpfile):
681        """
682        On Python 2 `Context.use_certificate_file` accepts a
683        filetype of type `long` as well as `int`.
684        """
685        pem_filename = tmpfile
686        with open(pem_filename, "wb") as pem_file:
687            pem_file.write(cleartextCertificatePEM)
688
689        ctx = Context(TLSv1_METHOD)
690        ctx.use_certificate_file(pem_filename, long(FILETYPE_PEM))
691
692    def test_check_privatekey_valid(self):
693        """
694        `Context.check_privatekey` returns `None` if the `Context` instance
695        has been configured to use a matched key and certificate pair.
696        """
697        key = load_privatekey(FILETYPE_PEM, client_key_pem)
698        cert = load_certificate(FILETYPE_PEM, client_cert_pem)
699        context = Context(TLSv1_METHOD)
700        context.use_privatekey(key)
701        context.use_certificate(cert)
702        assert None is context.check_privatekey()
703
704    def test_check_privatekey_invalid(self):
705        """
706        `Context.check_privatekey` raises `Error` if the `Context` instance
707        has been configured to use a key and certificate pair which don't
708        relate to each other.
709        """
710        key = load_privatekey(FILETYPE_PEM, client_key_pem)
711        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
712        context = Context(TLSv1_METHOD)
713        context.use_privatekey(key)
714        context.use_certificate(cert)
715        with pytest.raises(Error):
716            context.check_privatekey()
717
718    def test_app_data(self):
719        """
720        `Context.set_app_data` stores an object for later retrieval
721        using `Context.get_app_data`.
722        """
723        app_data = object()
724        context = Context(TLSv1_METHOD)
725        context.set_app_data(app_data)
726        assert context.get_app_data() is app_data
727
728    def test_set_options_wrong_args(self):
729        """
730        `Context.set_options` raises `TypeError` if called with
731        a non-`int` argument.
732        """
733        context = Context(TLSv1_METHOD)
734        with pytest.raises(TypeError):
735            context.set_options(None)
736
737    def test_set_options(self):
738        """
739        `Context.set_options` returns the new options value.
740        """
741        context = Context(TLSv1_METHOD)
742        options = context.set_options(OP_NO_SSLv2)
743        assert options & OP_NO_SSLv2 == OP_NO_SSLv2
744
745    @skip_if_py3
746    def test_set_options_long(self):
747        """
748        On Python 2 `Context.set_options` accepts values of type
749        `long` as well as `int`.
750        """
751        context = Context(TLSv1_METHOD)
752        options = context.set_options(long(OP_NO_SSLv2))
753        assert options & OP_NO_SSLv2 == OP_NO_SSLv2
754
755    def test_set_mode_wrong_args(self):
756        """
757        `Context.set_mode` raises `TypeError` if called with
758        a non-`int` argument.
759        """
760        context = Context(TLSv1_METHOD)
761        with pytest.raises(TypeError):
762            context.set_mode(None)
763
764    def test_set_mode(self):
765        """
766        `Context.set_mode` accepts a mode bitvector and returns the
767        newly set mode.
768        """
769        context = Context(TLSv1_METHOD)
770        assert MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS)
771
772    @skip_if_py3
773    def test_set_mode_long(self):
774        """
775        On Python 2 `Context.set_mode` accepts values of type `long` as well
776        as `int`.
777        """
778        context = Context(TLSv1_METHOD)
779        mode = context.set_mode(long(MODE_RELEASE_BUFFERS))
780        assert MODE_RELEASE_BUFFERS & mode
781
782    def test_set_timeout_wrong_args(self):
783        """
784        `Context.set_timeout` raises `TypeError` if called with
785        a non-`int` argument.
786        """
787        context = Context(TLSv1_METHOD)
788        with pytest.raises(TypeError):
789            context.set_timeout(None)
790
791    def test_timeout(self):
792        """
793        `Context.set_timeout` sets the session timeout for all connections
794        created using the context object. `Context.get_timeout` retrieves
795        this value.
796        """
797        context = Context(TLSv1_METHOD)
798        context.set_timeout(1234)
799        assert context.get_timeout() == 1234
800
801    @skip_if_py3
802    def test_timeout_long(self):
803        """
804        On Python 2 `Context.set_timeout` accepts values of type `long` as
805        well as int.
806        """
807        context = Context(TLSv1_METHOD)
808        context.set_timeout(long(1234))
809        assert context.get_timeout() == 1234
810
811    def test_set_verify_depth_wrong_args(self):
812        """
813        `Context.set_verify_depth` raises `TypeError` if called with a
814        non-`int` argument.
815        """
816        context = Context(TLSv1_METHOD)
817        with pytest.raises(TypeError):
818            context.set_verify_depth(None)
819
820    def test_verify_depth(self):
821        """
822        `Context.set_verify_depth` sets the number of certificates in
823        a chain to follow before giving up.  The value can be retrieved with
824        `Context.get_verify_depth`.
825        """
826        context = Context(TLSv1_METHOD)
827        context.set_verify_depth(11)
828        assert context.get_verify_depth() == 11
829
830    @skip_if_py3
831    def test_verify_depth_long(self):
832        """
833        On Python 2 `Context.set_verify_depth` accepts values of type `long`
834        as well as int.
835        """
836        context = Context(TLSv1_METHOD)
837        context.set_verify_depth(long(11))
838        assert context.get_verify_depth() == 11
839
840    def _write_encrypted_pem(self, passphrase, tmpfile):
841        """
842        Write a new private key out to a new file, encrypted using the given
843        passphrase.  Return the path to the new file.
844        """
845        key = PKey()
846        key.generate_key(TYPE_RSA, 512)
847        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase)
848        with open(tmpfile, 'w') as fObj:
849            fObj.write(pem.decode('ascii'))
850        return tmpfile
851
852    def test_set_passwd_cb_wrong_args(self):
853        """
854        `Context.set_passwd_cb` raises `TypeError` if called with a
855        non-callable first argument.
856        """
857        context = Context(TLSv1_METHOD)
858        with pytest.raises(TypeError):
859            context.set_passwd_cb(None)
860
861    def test_set_passwd_cb(self, tmpfile):
862        """
863        `Context.set_passwd_cb` accepts a callable which will be invoked when
864        a private key is loaded from an encrypted PEM.
865        """
866        passphrase = b"foobar"
867        pemFile = self._write_encrypted_pem(passphrase, tmpfile)
868        calledWith = []
869
870        def passphraseCallback(maxlen, verify, extra):
871            calledWith.append((maxlen, verify, extra))
872            return passphrase
873        context = Context(TLSv1_METHOD)
874        context.set_passwd_cb(passphraseCallback)
875        context.use_privatekey_file(pemFile)
876        assert len(calledWith) == 1
877        assert isinstance(calledWith[0][0], int)
878        assert isinstance(calledWith[0][1], int)
879        assert calledWith[0][2] is None
880
881    def test_passwd_callback_exception(self, tmpfile):
882        """
883        `Context.use_privatekey_file` propagates any exception raised
884        by the passphrase callback.
885        """
886        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
887
888        def passphraseCallback(maxlen, verify, extra):
889            raise RuntimeError("Sorry, I am a fail.")
890
891        context = Context(TLSv1_METHOD)
892        context.set_passwd_cb(passphraseCallback)
893        with pytest.raises(RuntimeError):
894            context.use_privatekey_file(pemFile)
895
896    def test_passwd_callback_false(self, tmpfile):
897        """
898        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the
899        passphrase callback returns a false value.
900        """
901        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
902
903        def passphraseCallback(maxlen, verify, extra):
904            return b""
905
906        context = Context(TLSv1_METHOD)
907        context.set_passwd_cb(passphraseCallback)
908        with pytest.raises(Error):
909            context.use_privatekey_file(pemFile)
910
911    def test_passwd_callback_non_string(self, tmpfile):
912        """
913        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the
914        passphrase callback returns a true non-string value.
915        """
916        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)
917
918        def passphraseCallback(maxlen, verify, extra):
919            return 10
920
921        context = Context(TLSv1_METHOD)
922        context.set_passwd_cb(passphraseCallback)
923        # TODO: Surely this is the wrong error?
924        with pytest.raises(ValueError):
925            context.use_privatekey_file(pemFile)
926
927    def test_passwd_callback_too_long(self, tmpfile):
928        """
929        If the passphrase returned by the passphrase callback returns a string
930        longer than the indicated maximum length, it is truncated.
931        """
932        # A priori knowledge!
933        passphrase = b"x" * 1024
934        pemFile = self._write_encrypted_pem(passphrase, tmpfile)
935
936        def passphraseCallback(maxlen, verify, extra):
937            assert maxlen == 1024
938            return passphrase + b"y"
939
940        context = Context(TLSv1_METHOD)
941        context.set_passwd_cb(passphraseCallback)
942        # This shall succeed because the truncated result is the correct
943        # passphrase.
944        context.use_privatekey_file(pemFile)
945
946    def test_set_info_callback(self):
947        """
948        `Context.set_info_callback` accepts a callable which will be
949        invoked when certain information about an SSL connection is available.
950        """
951        (server, client) = socket_pair()
952
953        clientSSL = Connection(Context(TLSv1_METHOD), client)
954        clientSSL.set_connect_state()
955
956        called = []
957
958        def info(conn, where, ret):
959            called.append((conn, where, ret))
960        context = Context(TLSv1_METHOD)
961        context.set_info_callback(info)
962        context.use_certificate(
963            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
964        context.use_privatekey(
965            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
966
967        serverSSL = Connection(context, server)
968        serverSSL.set_accept_state()
969
970        handshake(clientSSL, serverSSL)
971
972        # The callback must always be called with a Connection instance as the
973        # first argument.  It would probably be better to split this into
974        # separate tests for client and server side info callbacks so we could
975        # assert it is called with the right Connection instance.  It would
976        # also be good to assert *something* about `where` and `ret`.
977        notConnections = [
978            conn for (conn, where, ret) in called
979            if not isinstance(conn, Connection)]
980        assert [] == notConnections, (
981            "Some info callback arguments were not Connection instances.")
982
983    def _load_verify_locations_test(self, *args):
984        """
985        Create a client context which will verify the peer certificate and call
986        its `load_verify_locations` method with the given arguments.
987        Then connect it to a server and ensure that the handshake succeeds.
988        """
989        (server, client) = socket_pair()
990
991        clientContext = Context(TLSv1_METHOD)
992        clientContext.load_verify_locations(*args)
993        # Require that the server certificate verify properly or the
994        # connection will fail.
995        clientContext.set_verify(
996            VERIFY_PEER,
997            lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
998
999        clientSSL = Connection(clientContext, client)
1000        clientSSL.set_connect_state()
1001
1002        serverContext = Context(TLSv1_METHOD)
1003        serverContext.use_certificate(
1004            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
1005        serverContext.use_privatekey(
1006            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
1007
1008        serverSSL = Connection(serverContext, server)
1009        serverSSL.set_accept_state()
1010
1011        # Without load_verify_locations above, the handshake
1012        # will fail:
1013        # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
1014        #          'certificate verify failed')]
1015        handshake(clientSSL, serverSSL)
1016
1017        cert = clientSSL.get_peer_certificate()
1018        assert cert.get_subject().CN == 'Testing Root CA'
1019
1020    def _load_verify_cafile(self, cafile):
1021        """
1022        Verify that if path to a file containing a certificate is passed to
1023        `Context.load_verify_locations` for the ``cafile`` parameter, that
1024        certificate is used as a trust root for the purposes of verifying
1025        connections created using that `Context`.
1026        """
1027        with open(cafile, 'w') as fObj:
1028            fObj.write(cleartextCertificatePEM.decode('ascii'))
1029
1030        self._load_verify_locations_test(cafile)
1031
1032    def test_load_verify_bytes_cafile(self, tmpfile):
1033        """
1034        `Context.load_verify_locations` accepts a file name as a `bytes`
1035        instance and uses the certificates within for verification purposes.
1036        """
1037        cafile = tmpfile + NON_ASCII.encode(getfilesystemencoding())
1038        self._load_verify_cafile(cafile)
1039
1040    def test_load_verify_unicode_cafile(self, tmpfile):
1041        """
1042        `Context.load_verify_locations` accepts a file name as a `unicode`
1043        instance and uses the certificates within for verification purposes.
1044        """
1045        self._load_verify_cafile(
1046            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
1047        )
1048
1049    def test_load_verify_invalid_file(self, tmpfile):
1050        """
1051        `Context.load_verify_locations` raises `Error` when passed a
1052        non-existent cafile.
1053        """
1054        clientContext = Context(TLSv1_METHOD)
1055        with pytest.raises(Error):
1056            clientContext.load_verify_locations(tmpfile)
1057
1058    def _load_verify_directory_locations_capath(self, capath):
1059        """
1060        Verify that if path to a directory containing certificate files is
1061        passed to ``Context.load_verify_locations`` for the ``capath``
1062        parameter, those certificates are used as trust roots for the purposes
1063        of verifying connections created using that ``Context``.
1064        """
1065        makedirs(capath)
1066        # Hash values computed manually with c_rehash to avoid depending on
1067        # c_rehash in the test suite.  One is from OpenSSL 0.9.8, the other
1068        # from OpenSSL 1.0.0.
1069        for name in [b'c7adac82.0', b'c3705638.0']:
1070            cafile = join_bytes_or_unicode(capath, name)
1071            with open(cafile, 'w') as fObj:
1072                fObj.write(cleartextCertificatePEM.decode('ascii'))
1073
1074        self._load_verify_locations_test(None, capath)
1075
1076    def test_load_verify_directory_bytes_capath(self, tmpfile):
1077        """
1078        `Context.load_verify_locations` accepts a directory name as a `bytes`
1079        instance and uses the certificates within for verification purposes.
1080        """
1081        self._load_verify_directory_locations_capath(
1082            tmpfile + NON_ASCII.encode(getfilesystemencoding())
1083        )
1084
1085    def test_load_verify_directory_unicode_capath(self, tmpfile):
1086        """
1087        `Context.load_verify_locations` accepts a directory name as a `unicode`
1088        instance and uses the certificates within for verification purposes.
1089        """
1090        self._load_verify_directory_locations_capath(
1091            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
1092        )
1093
1094    def test_load_verify_locations_wrong_args(self):
1095        """
1096        `Context.load_verify_locations` raises `TypeError` if with non-`str`
1097        arguments.
1098        """
1099        context = Context(TLSv1_METHOD)
1100        with pytest.raises(TypeError):
1101            context.load_verify_locations(object())
1102        with pytest.raises(TypeError):
1103            context.load_verify_locations(object(), object())
1104
1105    @pytest.mark.skipif(
1106        not platform.startswith("linux"),
1107        reason="Loading fallback paths is a linux-specific behavior to "
1108        "accommodate pyca/cryptography manylinux1 wheels"
1109    )
1110    def test_fallback_default_verify_paths(self, monkeypatch):
1111        """
1112        Test that we load certificates successfully on linux from the fallback
1113        path. To do this we set the _CRYPTOGRAPHY_MANYLINUX1_CA_FILE and
1114        _CRYPTOGRAPHY_MANYLINUX1_CA_DIR vars to be equal to whatever the
1115        current OpenSSL default is and we disable
1116        SSL_CTX_SET_default_verify_paths so that it can't find certs unless
1117        it loads via fallback.
1118        """
1119        context = Context(TLSv1_METHOD)
1120        monkeypatch.setattr(
1121            _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1
1122        )
1123        monkeypatch.setattr(
1124            SSL,
1125            "_CRYPTOGRAPHY_MANYLINUX1_CA_FILE",
1126            _ffi.string(_lib.X509_get_default_cert_file())
1127        )
1128        monkeypatch.setattr(
1129            SSL,
1130            "_CRYPTOGRAPHY_MANYLINUX1_CA_DIR",
1131            _ffi.string(_lib.X509_get_default_cert_dir())
1132        )
1133        context.set_default_verify_paths()
1134        store = context.get_cert_store()
1135        sk_obj = _lib.X509_STORE_get0_objects(store._store)
1136        assert sk_obj != _ffi.NULL
1137        num = _lib.sk_X509_OBJECT_num(sk_obj)
1138        assert num != 0
1139
1140    def test_check_env_vars(self, monkeypatch):
1141        """
1142        Test that we return True/False appropriately if the env vars are set.
1143        """
1144        context = Context(TLSv1_METHOD)
1145        dir_var = "CUSTOM_DIR_VAR"
1146        file_var = "CUSTOM_FILE_VAR"
1147        assert context._check_env_vars_set(dir_var, file_var) is False
1148        monkeypatch.setenv(dir_var, "value")
1149        monkeypatch.setenv(file_var, "value")
1150        assert context._check_env_vars_set(dir_var, file_var) is True
1151        assert context._check_env_vars_set(dir_var, file_var) is True
1152
1153    def test_verify_no_fallback_if_env_vars_set(self, monkeypatch):
1154        """
1155        Test that we don't use the fallback path if env vars are set.
1156        """
1157        context = Context(TLSv1_METHOD)
1158        monkeypatch.setattr(
1159            _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1
1160        )
1161        dir_env_var = _ffi.string(
1162            _lib.X509_get_default_cert_dir_env()
1163        ).decode("ascii")
1164        file_env_var = _ffi.string(
1165            _lib.X509_get_default_cert_file_env()
1166        ).decode("ascii")
1167        monkeypatch.setenv(dir_env_var, "value")
1168        monkeypatch.setenv(file_env_var, "value")
1169        context.set_default_verify_paths()
1170
1171        monkeypatch.setattr(
1172            context,
1173            "_fallback_default_verify_paths",
1174            raiser(SystemError)
1175        )
1176        context.set_default_verify_paths()
1177
1178    @pytest.mark.skipif(
1179        platform == "win32",
1180        reason="set_default_verify_paths appears not to work on Windows.  "
1181        "See LP#404343 and LP#404344."
1182    )
1183    def test_set_default_verify_paths(self):
1184        """
1185        `Context.set_default_verify_paths` causes the platform-specific CA
1186        certificate locations to be used for verification purposes.
1187        """
1188        # Testing this requires a server with a certificate signed by one
1189        # of the CAs in the platform CA location.  Getting one of those
1190        # costs money.  Fortunately (or unfortunately, depending on your
1191        # perspective), it's easy to think of a public server on the
1192        # internet which has such a certificate.  Connecting to the network
1193        # in a unit test is bad, but it's the only way I can think of to
1194        # really test this. -exarkun
1195        context = Context(SSLv23_METHOD)
1196        context.set_default_verify_paths()
1197        context.set_verify(
1198            VERIFY_PEER,
1199            lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
1200
1201        client = socket()
1202        client.connect(("encrypted.google.com", 443))
1203        clientSSL = Connection(context, client)
1204        clientSSL.set_connect_state()
1205        clientSSL.set_tlsext_host_name(b"encrypted.google.com")
1206        clientSSL.do_handshake()
1207        clientSSL.send(b"GET / HTTP/1.0\r\n\r\n")
1208        assert clientSSL.recv(1024)
1209
1210    def test_fallback_path_is_not_file_or_dir(self):
1211        """
1212        Test that when passed empty arrays or paths that do not exist no
1213        errors are raised.
1214        """
1215        context = Context(TLSv1_METHOD)
1216        context._fallback_default_verify_paths([], [])
1217        context._fallback_default_verify_paths(
1218            ["/not/a/file"], ["/not/a/dir"]
1219        )
1220
1221    def test_add_extra_chain_cert_invalid_cert(self):
1222        """
1223        `Context.add_extra_chain_cert` raises `TypeError` if called with an
1224        object which is not an instance of `X509`.
1225        """
1226        context = Context(TLSv1_METHOD)
1227        with pytest.raises(TypeError):
1228            context.add_extra_chain_cert(object())
1229
1230    def _handshake_test(self, serverContext, clientContext):
1231        """
1232        Verify that a client and server created with the given contexts can
1233        successfully handshake and communicate.
1234        """
1235        serverSocket, clientSocket = socket_pair()
1236
1237        server = Connection(serverContext, serverSocket)
1238        server.set_accept_state()
1239
1240        client = Connection(clientContext, clientSocket)
1241        client.set_connect_state()
1242
1243        # Make them talk to each other.
1244        # interact_in_memory(client, server)
1245        for _ in range(3):
1246            for s in [client, server]:
1247                try:
1248                    s.do_handshake()
1249                except WantReadError:
1250                    pass
1251
1252    def test_set_verify_callback_connection_argument(self):
1253        """
1254        The first argument passed to the verify callback is the
1255        `Connection` instance for which verification is taking place.
1256        """
1257        serverContext = Context(TLSv1_METHOD)
1258        serverContext.use_privatekey(
1259            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
1260        serverContext.use_certificate(
1261            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
1262        serverConnection = Connection(serverContext, None)
1263
1264        class VerifyCallback(object):
1265            def callback(self, connection, *args):
1266                self.connection = connection
1267                return 1
1268
1269        verify = VerifyCallback()
1270        clientContext = Context(TLSv1_METHOD)
1271        clientContext.set_verify(VERIFY_PEER, verify.callback)
1272        clientConnection = Connection(clientContext, None)
1273        clientConnection.set_connect_state()
1274
1275        handshake_in_memory(clientConnection, serverConnection)
1276
1277        assert verify.connection is clientConnection
1278
1279    def test_x509_in_verify_works(self):
1280        """
1281        We had a bug where the X509 cert instantiated in the callback wrapper
1282        didn't __init__ so it was missing objects needed when calling
1283        get_subject. This test sets up a handshake where we call get_subject
1284        on the cert provided to the verify callback.
1285        """
1286        serverContext = Context(TLSv1_METHOD)
1287        serverContext.use_privatekey(
1288            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
1289        serverContext.use_certificate(
1290            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
1291        serverConnection = Connection(serverContext, None)
1292
1293        def verify_cb_get_subject(conn, cert, errnum, depth, ok):
1294            assert cert.get_subject()
1295            return 1
1296
1297        clientContext = Context(TLSv1_METHOD)
1298        clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject)
1299        clientConnection = Connection(clientContext, None)
1300        clientConnection.set_connect_state()
1301
1302        handshake_in_memory(clientConnection, serverConnection)
1303
1304    def test_set_verify_callback_exception(self):
1305        """
1306        If the verify callback passed to `Context.set_verify` raises an
1307        exception, verification fails and the exception is propagated to the
1308        caller of `Connection.do_handshake`.
1309        """
1310        serverContext = Context(TLSv1_2_METHOD)
1311        serverContext.use_privatekey(
1312            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
1313        serverContext.use_certificate(
1314            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
1315
1316        clientContext = Context(TLSv1_2_METHOD)
1317
1318        def verify_callback(*args):
1319            raise Exception("silly verify failure")
1320        clientContext.set_verify(VERIFY_PEER, verify_callback)
1321
1322        with pytest.raises(Exception) as exc:
1323            self._handshake_test(serverContext, clientContext)
1324
1325        assert "silly verify failure" == str(exc.value)
1326
1327    def test_add_extra_chain_cert(self, tmpdir):
1328        """
1329        `Context.add_extra_chain_cert` accepts an `X509`
1330        instance to add to the certificate chain.
1331
1332        See `_create_certificate_chain` for the details of the
1333        certificate chain tested.
1334
1335        The chain is tested by starting a server with scert and connecting
1336        to it with a client which trusts cacert and requires verification to
1337        succeed.
1338        """
1339        chain = _create_certificate_chain()
1340        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
1341
1342        # Dump the CA certificate to a file because that's the only way to load
1343        # it as a trusted CA in the client context.
1344        for cert, name in [(cacert, 'ca.pem'),
1345                           (icert, 'i.pem'),
1346                           (scert, 's.pem')]:
1347            with tmpdir.join(name).open('w') as f:
1348                f.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii'))
1349
1350        for key, name in [(cakey, 'ca.key'),
1351                          (ikey, 'i.key'),
1352                          (skey, 's.key')]:
1353            with tmpdir.join(name).open('w') as f:
1354                f.write(dump_privatekey(FILETYPE_PEM, key).decode('ascii'))
1355
1356        # Create the server context
1357        serverContext = Context(TLSv1_METHOD)
1358        serverContext.use_privatekey(skey)
1359        serverContext.use_certificate(scert)
1360        # The client already has cacert, we only need to give them icert.
1361        serverContext.add_extra_chain_cert(icert)
1362
1363        # Create the client
1364        clientContext = Context(TLSv1_METHOD)
1365        clientContext.set_verify(
1366            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
1367        clientContext.load_verify_locations(str(tmpdir.join("ca.pem")))
1368
1369        # Try it out.
1370        self._handshake_test(serverContext, clientContext)
1371
1372    def _use_certificate_chain_file_test(self, certdir):
1373        """
1374        Verify that `Context.use_certificate_chain_file` reads a
1375        certificate chain from a specified file.
1376
1377        The chain is tested by starting a server with scert and connecting to
1378        it with a client which trusts cacert and requires verification to
1379        succeed.
1380        """
1381        chain = _create_certificate_chain()
1382        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
1383
1384        makedirs(certdir)
1385
1386        chainFile = join_bytes_or_unicode(certdir, "chain.pem")
1387        caFile = join_bytes_or_unicode(certdir, "ca.pem")
1388
1389        # Write out the chain file.
1390        with open(chainFile, 'wb') as fObj:
1391            # Most specific to least general.
1392            fObj.write(dump_certificate(FILETYPE_PEM, scert))
1393            fObj.write(dump_certificate(FILETYPE_PEM, icert))
1394            fObj.write(dump_certificate(FILETYPE_PEM, cacert))
1395
1396        with open(caFile, 'w') as fObj:
1397            fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii'))
1398
1399        serverContext = Context(TLSv1_METHOD)
1400        serverContext.use_certificate_chain_file(chainFile)
1401        serverContext.use_privatekey(skey)
1402
1403        clientContext = Context(TLSv1_METHOD)
1404        clientContext.set_verify(
1405            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
1406        clientContext.load_verify_locations(caFile)
1407
1408        self._handshake_test(serverContext, clientContext)
1409
1410    def test_use_certificate_chain_file_bytes(self, tmpfile):
1411        """
1412        ``Context.use_certificate_chain_file`` accepts the name of a file (as
1413        an instance of ``bytes``) to specify additional certificates to use to
1414        construct and verify a trust chain.
1415        """
1416        self._use_certificate_chain_file_test(
1417            tmpfile + NON_ASCII.encode(getfilesystemencoding())
1418        )
1419
1420    def test_use_certificate_chain_file_unicode(self, tmpfile):
1421        """
1422        ``Context.use_certificate_chain_file`` accepts the name of a file (as
1423        an instance of ``unicode``) to specify additional certificates to use
1424        to construct and verify a trust chain.
1425        """
1426        self._use_certificate_chain_file_test(
1427            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
1428        )
1429
1430    def test_use_certificate_chain_file_wrong_args(self):
1431        """
1432        `Context.use_certificate_chain_file` raises `TypeError` if passed a
1433        non-byte string single argument.
1434        """
1435        context = Context(TLSv1_METHOD)
1436        with pytest.raises(TypeError):
1437            context.use_certificate_chain_file(object())
1438
1439    def test_use_certificate_chain_file_missing_file(self, tmpfile):
1440        """
1441        `Context.use_certificate_chain_file` raises `OpenSSL.SSL.Error` when
1442        passed a bad chain file name (for example, the name of a file which
1443        does not exist).
1444        """
1445        context = Context(TLSv1_METHOD)
1446        with pytest.raises(Error):
1447            context.use_certificate_chain_file(tmpfile)
1448
1449    def test_set_verify_mode(self):
1450        """
1451        `Context.get_verify_mode` returns the verify mode flags previously
1452        passed to `Context.set_verify`.
1453        """
1454        context = Context(TLSv1_METHOD)
1455        assert context.get_verify_mode() == 0
1456        context.set_verify(
1457            VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None)
1458        assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE)
1459
1460    @skip_if_py3
1461    def test_set_verify_mode_long(self):
1462        """
1463        On Python 2 `Context.set_verify_mode` accepts values of type `long`
1464        as well as `int`.
1465        """
1466        context = Context(TLSv1_METHOD)
1467        assert context.get_verify_mode() == 0
1468        context.set_verify(
1469            long(VERIFY_PEER | VERIFY_CLIENT_ONCE), lambda *args: None
1470        )  # pragma: nocover
1471        assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE)
1472
1473    @pytest.mark.parametrize('mode', [None, 1.0, object(), 'mode'])
1474    def test_set_verify_wrong_mode_arg(self, mode):
1475        """
1476        `Context.set_verify` raises `TypeError` if the first argument is
1477        not an integer.
1478        """
1479        context = Context(TLSv1_METHOD)
1480        with pytest.raises(TypeError):
1481            context.set_verify(mode=mode, callback=lambda *args: None)
1482
1483    @pytest.mark.parametrize('callback', [None, 1.0, 'mode', ('foo', 'bar')])
1484    def test_set_verify_wrong_callable_arg(self, callback):
1485        """
1486        `Context.set_verify` raises `TypeError` if the second argument
1487        is not callable.
1488        """
1489        context = Context(TLSv1_METHOD)
1490        with pytest.raises(TypeError):
1491            context.set_verify(mode=VERIFY_PEER, callback=callback)
1492
1493    def test_load_tmp_dh_wrong_args(self):
1494        """
1495        `Context.load_tmp_dh` raises `TypeError` if called with a
1496        non-`str` argument.
1497        """
1498        context = Context(TLSv1_METHOD)
1499        with pytest.raises(TypeError):
1500            context.load_tmp_dh(object())
1501
1502    def test_load_tmp_dh_missing_file(self):
1503        """
1504        `Context.load_tmp_dh` raises `OpenSSL.SSL.Error` if the
1505        specified file does not exist.
1506        """
1507        context = Context(TLSv1_METHOD)
1508        with pytest.raises(Error):
1509            context.load_tmp_dh(b"hello")
1510
1511    def _load_tmp_dh_test(self, dhfilename):
1512        """
1513        Verify that calling ``Context.load_tmp_dh`` with the given filename
1514        does not raise an exception.
1515        """
1516        context = Context(TLSv1_METHOD)
1517        with open(dhfilename, "w") as dhfile:
1518            dhfile.write(dhparam)
1519
1520        context.load_tmp_dh(dhfilename)
1521        # XXX What should I assert here? -exarkun
1522
1523    def test_load_tmp_dh_bytes(self, tmpfile):
1524        """
1525        `Context.load_tmp_dh` loads Diffie-Hellman parameters from the
1526        specified file (given as ``bytes``).
1527        """
1528        self._load_tmp_dh_test(
1529            tmpfile + NON_ASCII.encode(getfilesystemencoding()),
1530        )
1531
1532    def test_load_tmp_dh_unicode(self, tmpfile):
1533        """
1534        `Context.load_tmp_dh` loads Diffie-Hellman parameters from the
1535        specified file (given as ``unicode``).
1536        """
1537        self._load_tmp_dh_test(
1538            tmpfile.decode(getfilesystemencoding()) + NON_ASCII,
1539        )
1540
1541    def test_set_tmp_ecdh(self):
1542        """
1543        `Context.set_tmp_ecdh` sets the elliptic curve for Diffie-Hellman to
1544        the specified curve.
1545        """
1546        context = Context(TLSv1_METHOD)
1547        for curve in get_elliptic_curves():
1548            if curve.name.startswith(u"Oakley-"):
1549                # Setting Oakley-EC2N-4 and Oakley-EC2N-3 adds
1550                # ('bignum routines', 'BN_mod_inverse', 'no inverse') to the
1551                # error queue on OpenSSL 1.0.2.
1552                continue
1553            # The only easily "assertable" thing is that it does not raise an
1554            # exception.
1555            context.set_tmp_ecdh(curve)
1556
1557    def test_set_session_cache_mode_wrong_args(self):
1558        """
1559        `Context.set_session_cache_mode` raises `TypeError` if called with
1560        a non-integer argument.
1561        called with other than one integer argument.
1562        """
1563        context = Context(TLSv1_METHOD)
1564        with pytest.raises(TypeError):
1565            context.set_session_cache_mode(object())
1566
1567    def test_session_cache_mode(self):
1568        """
1569        `Context.set_session_cache_mode` specifies how sessions are cached.
1570        The setting can be retrieved via `Context.get_session_cache_mode`.
1571        """
1572        context = Context(TLSv1_METHOD)
1573        context.set_session_cache_mode(SESS_CACHE_OFF)
1574        off = context.set_session_cache_mode(SESS_CACHE_BOTH)
1575        assert SESS_CACHE_OFF == off
1576        assert SESS_CACHE_BOTH == context.get_session_cache_mode()
1577
1578    @skip_if_py3
1579    def test_session_cache_mode_long(self):
1580        """
1581        On Python 2 `Context.set_session_cache_mode` accepts values
1582        of type `long` as well as `int`.
1583        """
1584        context = Context(TLSv1_METHOD)
1585        context.set_session_cache_mode(long(SESS_CACHE_BOTH))
1586        assert SESS_CACHE_BOTH == context.get_session_cache_mode()
1587
1588    def test_get_cert_store(self):
1589        """
1590        `Context.get_cert_store` returns a `X509Store` instance.
1591        """
1592        context = Context(TLSv1_METHOD)
1593        store = context.get_cert_store()
1594        assert isinstance(store, X509Store)
1595
1596    def test_set_tlsext_use_srtp_not_bytes(self):
1597        """
1598        `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material.
1599
1600        It raises a TypeError if the list of profiles is not a byte string.
1601        """
1602        context = Context(TLSv1_METHOD)
1603        with pytest.raises(TypeError):
1604            context.set_tlsext_use_srtp(text_type('SRTP_AES128_CM_SHA1_80'))
1605
1606    def test_set_tlsext_use_srtp_invalid_profile(self):
1607        """
1608        `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material.
1609
1610        It raises an Error if the call to OpenSSL fails.
1611        """
1612        context = Context(TLSv1_METHOD)
1613        with pytest.raises(Error):
1614            context.set_tlsext_use_srtp(b'SRTP_BOGUS')
1615
1616    def test_set_tlsext_use_srtp_valid(self):
1617        """
1618        `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material.
1619
1620        It does not return anything.
1621        """
1622        context = Context(TLSv1_METHOD)
1623        assert context.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80') is None
1624
1625
class TestServerNameCallback(object):
    """
    Tests covering `Context.set_tlsext_servername_callback` and how the
    registered callback interacts with `Connection`.
    """
    def test_old_callback_forgotten(self):
        """
        Registering a second callback through
        `Context.set_tlsext_servername_callback` releases the reference
        that was held on the first one.
        """
        def original(connection):  # pragma: no cover
            pass

        def replacement(connection):  # pragma: no cover
            pass

        ctx = Context(TLSv1_METHOD)
        ctx.set_tlsext_servername_callback(original)

        # Keep only a weak reference so the context is the sole owner.
        weak_original = ref(original)
        del original

        ctx.set_tlsext_servername_callback(replacement)

        # A single collection is enough on CPython; PyPy only reclaims the
        # object on a second pass.  Either way, the object going away shows
        # that the context dropped its reference.
        for _ in range(2):
            collect()

        survivor = weak_original()
        if survivor is not None:
            referrers = get_referrers(survivor)
            if len(referrers) > 1:  # pragma: nocover
                pytest.fail("Some references remain: %r" % (referrers,))

    def test_no_servername(self):
        """
        If the client sends no server name, the callback registered with
        `Context.set_tlsext_servername_callback` still runs, and
        `Connection.get_servername` returns `None` inside it.
        """
        calls = []

        def record(conn):
            calls.append((conn, conn.get_servername()))

        server_ctx = Context(TLSv1_METHOD)
        server_ctx.set_tlsext_servername_callback(record)

        # Drop the only local reference; from here on the Context alone has
        # to keep the callback alive.
        del record
        collect()

        # The server needs a key and certificate to accept the connection.
        server_ctx.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem))
        server_ctx.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem))

        # Run an in-memory exchange so the callback gets triggered.
        server = Connection(server_ctx, None)
        server.set_accept_state()

        client = Connection(Context(TLSv1_METHOD), None)
        client.set_connect_state()

        interact_in_memory(server, client)

        assert calls == [(server, None)]

    def test_servername(self):
        """
        If the client names a server in its hello message, the callback
        registered with `Context.set_tlsext_servername_callback` runs and
        `Connection.get_servername` returns that server name.
        """
        calls = []

        def record(conn):
            calls.append((conn, conn.get_servername()))

        server_ctx = Context(TLSv1_METHOD)
        server_ctx.set_tlsext_servername_callback(record)

        # The server needs a key and certificate to accept the connection.
        server_ctx.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem))
        server_ctx.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem))

        # Run an in-memory exchange so the callback gets triggered.
        server = Connection(server_ctx, None)
        server.set_accept_state()

        client = Connection(Context(TLSv1_METHOD), None)
        client.set_connect_state()
        client.set_tlsext_host_name(b"foo1.example.com")

        interact_in_memory(server, client)

        assert calls == [(server, b"foo1.example.com")]
1727
1728
class TestNextProtoNegotiation(object):
    """
    Test for Next Protocol Negotiation in PyOpenSSL.
    """
    @staticmethod
    def _npn_connections(advertise, select):
        """
        Build an in-memory server/client pair with NPN callbacks installed.

        :param advertise: callback registered on the server context with
            `Context.set_npn_advertise_callback`.
        :param select: callback registered on the client context with
            `Context.set_npn_select_callback`.
        :return: a ``(server, client)`` tuple of `Connection` objects; the
            server is in accept state, the client in connect state, and no
            handshake data has been exchanged yet.
        """
        server_context = Context(TLSv1_METHOD)
        server_context.set_npn_advertise_callback(advertise)

        client_context = Context(TLSv1_METHOD)
        client_context.set_npn_select_callback(select)

        # Necessary to actually accept the connection
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem))
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem))

        server = Connection(server_context, None)
        server.set_accept_state()

        client = Connection(client_context, None)
        client.set_connect_state()

        return server, client

    def test_npn_success(self):
        """
        Tests that clients and servers that agree on the negotiated next
        protocol can correctly establish a connection, and that the agreed
        protocol is reported by the connections.
        """
        advertise_args = []
        select_args = []

        def advertise(conn):
            advertise_args.append((conn,))
            return [b'http/1.1', b'spdy/2']

        def select(conn, options):
            select_args.append((conn, options))
            return b'spdy/2'

        server, client = self._npn_connections(advertise, select)

        # Do a little connection to trigger the logic
        interact_in_memory(server, client)

        assert advertise_args == [(server,)]
        assert select_args == [(client, [b'http/1.1', b'spdy/2'])]

        assert server.get_next_proto_negotiated() == b'spdy/2'
        assert client.get_next_proto_negotiated() == b'spdy/2'

    def test_npn_client_fail(self):
        """
        Tests that when clients and servers cannot agree on what protocol
        to use next that the TLS connection does not get established.
        """
        advertise_args = []
        select_args = []

        def advertise(conn):
            advertise_args.append((conn,))
            return [b'http/1.1', b'spdy/2']

        def select(conn, options):
            select_args.append((conn, options))
            return b''

        server, client = self._npn_connections(advertise, select)

        # The client's select callback returns an empty protocol, so the
        # connection will fail.
        with pytest.raises(Error):
            interact_in_memory(server, client)

        assert advertise_args == [(server,)]
        assert select_args == [(client, [b'http/1.1', b'spdy/2'])]

    def test_npn_select_error(self):
        """
        Test that we can handle exceptions in the select callback. If
        select fails it should be fatal to the connection.
        """
        advertise_args = []

        def advertise(conn):
            advertise_args.append((conn,))
            return [b'http/1.1', b'spdy/2']

        def select(conn, options):
            raise TypeError

        server, client = self._npn_connections(advertise, select)

        # If the callback throws an exception it should be raised here.
        with pytest.raises(TypeError):
            interact_in_memory(server, client)
        assert advertise_args == [(server,), ]

    def test_npn_advertise_error(self):
        """
        Test that we can handle exceptions in the advertise callback. If
        advertise fails no NPN is advertised to the client.
        """
        select_args = []

        def advertise(conn):
            raise TypeError

        def select(conn, options):  # pragma: nocover
            """
            Assert later that no args are actually appended.
            """
            select_args.append((conn, options))
            return b''

        server, client = self._npn_connections(advertise, select)

        # The exception thrown by the advertise callback should be raised
        # here, and the client's select callback should never run.
        with pytest.raises(TypeError):
            interact_in_memory(server, client)
        assert select_args == []
1897
1898
class TestApplicationLayerProtoNegotiation(object):
    """
    Tests for ALPN in PyOpenSSL.
    """
    # Skip tests on versions that don't support ALPN.
    if _lib.Cryptography_HAS_ALPN:

        @staticmethod
        def _accepting_server(select=None):
            """
            Build a server-side `Connection` that is ready to accept.

            :param select: optional ALPN select callback.  When given it is
                registered with `Context.set_alpn_select_callback`; when
                `None` the server context gets no ALPN callback at all.
            :return: an in-memory `Connection` in accept state whose context
                has the test server key and certificate loaded.
            """
            server_context = Context(TLSv1_METHOD)
            if select is not None:
                server_context.set_alpn_select_callback(select)

            # Necessary to actually accept the connection
            server_context.use_privatekey(
                load_privatekey(FILETYPE_PEM, server_key_pem))
            server_context.use_certificate(
                load_certificate(FILETYPE_PEM, server_cert_pem))

            server = Connection(server_context, None)
            server.set_accept_state()
            return server

        def test_alpn_success(self):
            """
            Clients and servers that agree on the negotiated ALPN protocol can
            correctly establish a connection, and the agreed protocol is
            reported by the connections.
            """
            select_args = []

            def select(conn, options):
                select_args.append((conn, options))
                return b'spdy/2'

            client_context = Context(TLSv1_METHOD)
            client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

            server = self._accepting_server(select)

            client = Connection(client_context, None)
            client.set_connect_state()

            # Do a little connection to trigger the logic
            interact_in_memory(server, client)

            assert select_args == [(server, [b'http/1.1', b'spdy/2'])]

            assert server.get_alpn_proto_negotiated() == b'spdy/2'
            assert client.get_alpn_proto_negotiated() == b'spdy/2'

        def test_alpn_set_on_connection(self):
            """
            The same as test_alpn_success, but setting the ALPN protocols on
            the connection rather than the context.
            """
            select_args = []

            def select(conn, options):
                select_args.append((conn, options))
                return b'spdy/2'

            # Setup the client context but don't set any ALPN protocols.
            client_context = Context(TLSv1_METHOD)

            server = self._accepting_server(select)

            # Set the ALPN protocols on the client connection.
            client = Connection(client_context, None)
            client.set_alpn_protos([b'http/1.1', b'spdy/2'])
            client.set_connect_state()

            interact_in_memory(server, client)

            assert select_args == [(server, [b'http/1.1', b'spdy/2'])]

            assert server.get_alpn_proto_negotiated() == b'spdy/2'
            assert client.get_alpn_proto_negotiated() == b'spdy/2'

        def test_alpn_server_fail(self):
            """
            When clients and servers cannot agree on what protocol to use next
            the TLS connection does not get established.
            """
            select_args = []

            def select(conn, options):
                select_args.append((conn, options))
                return b''

            client_context = Context(TLSv1_METHOD)
            client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

            server = self._accepting_server(select)

            client = Connection(client_context, None)
            client.set_connect_state()

            # The server's select callback returns an empty protocol, so the
            # connection will fail.
            with pytest.raises(Error):
                interact_in_memory(server, client)

            assert select_args == [(server, [b'http/1.1', b'spdy/2'])]

        def test_alpn_no_server(self):
            """
            When clients and servers cannot agree on what protocol to use next
            because the server doesn't offer ALPN, no protocol is negotiated.
            """
            client_context = Context(TLSv1_METHOD)
            client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

            # No select callback: the server does not take part in ALPN.
            server = self._accepting_server()

            client = Connection(client_context, None)
            client.set_connect_state()

            # Do the dance.
            interact_in_memory(server, client)

            assert client.get_alpn_proto_negotiated() == b''

        def test_alpn_callback_exception(self):
            """
            We can handle exceptions in the ALPN select callback.
            """
            select_args = []

            def select(conn, options):
                select_args.append((conn, options))
                raise TypeError()

            client_context = Context(TLSv1_METHOD)
            client_context.set_alpn_protos([b'http/1.1', b'spdy/2'])

            server = self._accepting_server(select)

            client = Connection(client_context, None)
            client.set_connect_state()

            # The exception thrown by the select callback is raised here.
            with pytest.raises(TypeError):
                interact_in_memory(server, client)
            assert select_args == [(server, [b'http/1.1', b'spdy/2'])]

    else:
        # No ALPN.
        def test_alpn_not_implemented(self):
            """
            If ALPN is not in OpenSSL, we should raise NotImplementedError.
            """
            # Test the context methods first.
            context = Context(TLSv1_METHOD)
            with pytest.raises(NotImplementedError):
                context.set_alpn_protos(None)
            with pytest.raises(NotImplementedError):
                context.set_alpn_select_callback(None)

            # Now test a connection.
            conn = Connection(context)
            with pytest.raises(NotImplementedError):
                conn.set_alpn_protos(None)
2097
2098
class TestSession(object):
    """
    Unit tests for :py:obj:`OpenSSL.SSL.Session`.
    """
    def test_construction(self):
        """
        Calling :py:class:`Session` with no arguments yields a new instance
        of that type.
        """
        assert isinstance(Session(), Session)
2110
2111
2112class TestConnection(object):
2113    """
2114    Unit tests for `OpenSSL.SSL.Connection`.
2115    """
2116    # XXX get_peer_certificate -> None
2117    # XXX sock_shutdown
2118    # XXX master_key -> TypeError
2119    # XXX server_random -> TypeError
2120    # XXX connect -> TypeError
2121    # XXX connect_ex -> TypeError
2122    # XXX set_connect_state -> TypeError
2123    # XXX set_accept_state -> TypeError
2124    # XXX do_handshake -> TypeError
2125    # XXX bio_read -> TypeError
2126    # XXX recv -> TypeError
2127    # XXX send -> TypeError
2128    # XXX bio_write -> TypeError
2129
2130    def test_type(self):
2131        """
2132        `Connection` and `ConnectionType` refer to the same type object and
2133        can be used to create instances of that type.
2134        """
2135        assert Connection is ConnectionType
2136        ctx = Context(TLSv1_METHOD)
2137        assert is_consistent_type(Connection, 'Connection', ctx, None)
2138
2139    @pytest.mark.parametrize('bad_context', [object(), 'context', None, 1])
2140    def test_wrong_args(self, bad_context):
2141        """
2142        `Connection.__init__` raises `TypeError` if called with a non-`Context`
2143        instance argument.
2144        """
2145        with pytest.raises(TypeError):
2146            Connection(bad_context)
2147
2148    def test_get_context(self):
2149        """
2150        `Connection.get_context` returns the `Context` instance used to
2151        construct the `Connection` instance.
2152        """
2153        context = Context(TLSv1_METHOD)
2154        connection = Connection(context, None)
2155        assert connection.get_context() is context
2156
2157    def test_set_context_wrong_args(self):
2158        """
2159        `Connection.set_context` raises `TypeError` if called with a
2160        non-`Context` instance argument.
2161        """
2162        ctx = Context(TLSv1_METHOD)
2163        connection = Connection(ctx, None)
2164        with pytest.raises(TypeError):
2165            connection.set_context(object())
2166        with pytest.raises(TypeError):
2167            connection.set_context("hello")
2168        with pytest.raises(TypeError):
2169            connection.set_context(1)
2170        assert ctx is connection.get_context()
2171
2172    def test_set_context(self):
2173        """
2174        `Connection.set_context` specifies a new `Context` instance to be
2175        used for the connection.
2176        """
2177        original = Context(SSLv23_METHOD)
2178        replacement = Context(TLSv1_METHOD)
2179        connection = Connection(original, None)
2180        connection.set_context(replacement)
2181        assert replacement is connection.get_context()
2182        # Lose our references to the contexts, just in case the Connection
2183        # isn't properly managing its own contributions to their reference
2184        # counts.
2185        del original, replacement
2186        collect()
2187
2188    def test_set_tlsext_host_name_wrong_args(self):
2189        """
2190        If `Connection.set_tlsext_host_name` is called with a non-byte string
2191        argument or a byte string with an embedded NUL, `TypeError` is raised.
2192        """
2193        conn = Connection(Context(TLSv1_METHOD), None)
2194        with pytest.raises(TypeError):
2195            conn.set_tlsext_host_name(object())
2196        with pytest.raises(TypeError):
2197            conn.set_tlsext_host_name(b"with\0null")
2198
2199        if PY3:
2200            # On Python 3.x, don't accidentally implicitly convert from text.
2201            with pytest.raises(TypeError):
2202                conn.set_tlsext_host_name(b"example.com".decode("ascii"))
2203
2204    def test_pending(self):
2205        """
2206        `Connection.pending` returns the number of bytes available for
2207        immediate read.
2208        """
2209        connection = Connection(Context(TLSv1_METHOD), None)
2210        assert connection.pending() == 0
2211
2212    def test_peek(self):
2213        """
2214        `Connection.recv` peeks into the connection if `socket.MSG_PEEK` is
2215        passed.
2216        """
2217        server, client = loopback()
2218        server.send(b'xy')
2219        assert client.recv(2, MSG_PEEK) == b'xy'
2220        assert client.recv(2, MSG_PEEK) == b'xy'
2221        assert client.recv(2) == b'xy'
2222
2223    def test_connect_wrong_args(self):
2224        """
2225        `Connection.connect` raises `TypeError` if called with a non-address
2226        argument.
2227        """
2228        connection = Connection(Context(TLSv1_METHOD), socket())
2229        with pytest.raises(TypeError):
2230            connection.connect(None)
2231
2232    def test_connect_refused(self):
2233        """
2234        `Connection.connect` raises `socket.error` if the underlying socket
2235        connect method raises it.
2236        """
2237        client = socket()
2238        context = Context(TLSv1_METHOD)
2239        clientSSL = Connection(context, client)
2240        # pytest.raises here doesn't work because of a bug in py.test on Python
2241        # 2.6: https://github.com/pytest-dev/pytest/issues/988
2242        try:
2243            clientSSL.connect(("127.0.0.1", 1))
2244        except error as e:
2245            exc = e
2246        assert exc.args[0] == ECONNREFUSED
2247
2248    def test_connect(self):
2249        """
2250        `Connection.connect` establishes a connection to the specified address.
2251        """
2252        port = socket()
2253        port.bind(('', 0))
2254        port.listen(3)
2255
2256        clientSSL = Connection(Context(TLSv1_METHOD), socket())
2257        clientSSL.connect(('127.0.0.1', port.getsockname()[1]))
2258        # XXX An assertion?  Or something?
2259
2260    @pytest.mark.skipif(
2261        platform == "darwin",
2262        reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4"
2263    )
2264    def test_connect_ex(self):
2265        """
2266        If there is a connection error, `Connection.connect_ex` returns the
2267        errno instead of raising an exception.
2268        """
2269        port = socket()
2270        port.bind(('', 0))
2271        port.listen(3)
2272
2273        clientSSL = Connection(Context(TLSv1_METHOD), socket())
2274        clientSSL.setblocking(False)
2275        result = clientSSL.connect_ex(port.getsockname())
2276        expected = (EINPROGRESS, EWOULDBLOCK)
2277        assert result in expected
2278
2279    def test_accept(self):
2280        """
2281        `Connection.accept` accepts a pending connection attempt and returns a
2282        tuple of a new `Connection` (the accepted client) and the address the
2283        connection originated from.
2284        """
2285        ctx = Context(TLSv1_METHOD)
2286        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
2287        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
2288        port = socket()
2289        portSSL = Connection(ctx, port)
2290        portSSL.bind(('', 0))
2291        portSSL.listen(3)
2292
2293        clientSSL = Connection(Context(TLSv1_METHOD), socket())
2294
2295        # Calling portSSL.getsockname() here to get the server IP address
2296        # sounds great, but frequently fails on Windows.
2297        clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1]))
2298
2299        serverSSL, address = portSSL.accept()
2300
2301        assert isinstance(serverSSL, Connection)
2302        assert serverSSL.get_context() is ctx
2303        assert address == clientSSL.getsockname()
2304
2305    def test_shutdown_wrong_args(self):
2306        """
2307        `Connection.set_shutdown` raises `TypeError` if called with arguments
2308        other than integers.
2309        """
2310        connection = Connection(Context(TLSv1_METHOD), None)
2311        with pytest.raises(TypeError):
2312            connection.set_shutdown(None)
2313
2314    def test_shutdown(self):
2315        """
2316        `Connection.shutdown` performs an SSL-level connection shutdown.
2317        """
2318        server, client = loopback()
2319        assert not server.shutdown()
2320        assert server.get_shutdown() == SENT_SHUTDOWN
2321        with pytest.raises(ZeroReturnError):
2322            client.recv(1024)
2323        assert client.get_shutdown() == RECEIVED_SHUTDOWN
2324        client.shutdown()
2325        assert client.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN)
2326        with pytest.raises(ZeroReturnError):
2327            server.recv(1024)
2328        assert server.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN)
2329
2330    def test_shutdown_closed(self):
2331        """
2332        If the underlying socket is closed, `Connection.shutdown` propagates
2333        the write error from the low level write call.
2334        """
2335        server, client = loopback()
2336        server.sock_shutdown(2)
2337        with pytest.raises(SysCallError) as exc:
2338            server.shutdown()
2339        if platform == "win32":
2340            assert exc.value.args[0] == ESHUTDOWN
2341        else:
2342            assert exc.value.args[0] == EPIPE
2343
2344    def test_shutdown_truncated(self):
2345        """
2346        If the underlying connection is truncated, `Connection.shutdown`
2347        raises an `Error`.
2348        """
2349        server_ctx = Context(TLSv1_METHOD)
2350        client_ctx = Context(TLSv1_METHOD)
2351        server_ctx.use_privatekey(
2352            load_privatekey(FILETYPE_PEM, server_key_pem))
2353        server_ctx.use_certificate(
2354            load_certificate(FILETYPE_PEM, server_cert_pem))
2355        server = Connection(server_ctx, None)
2356        client = Connection(client_ctx, None)
2357        handshake_in_memory(client, server)
2358        assert not server.shutdown()
2359        with pytest.raises(WantReadError):
2360            server.shutdown()
2361        server.bio_shutdown()
2362        with pytest.raises(Error):
2363            server.shutdown()
2364
2365    def test_set_shutdown(self):
2366        """
2367        `Connection.set_shutdown` sets the state of the SSL connection
2368        shutdown process.
2369        """
2370        connection = Connection(Context(TLSv1_METHOD), socket())
2371        connection.set_shutdown(RECEIVED_SHUTDOWN)
2372        assert connection.get_shutdown() == RECEIVED_SHUTDOWN
2373
2374    @skip_if_py3
2375    def test_set_shutdown_long(self):
2376        """
2377        On Python 2 `Connection.set_shutdown` accepts an argument
2378        of type `long` as well as `int`.
2379        """
2380        connection = Connection(Context(TLSv1_METHOD), socket())
2381        connection.set_shutdown(long(RECEIVED_SHUTDOWN))
2382        assert connection.get_shutdown() == RECEIVED_SHUTDOWN
2383
2384    def test_state_string(self):
2385        """
2386        `Connection.state_string` verbosely describes the current state of
2387        the `Connection`.
2388        """
2389        server, client = socket_pair()
2390        server = loopback_server_factory(server)
2391        client = loopback_client_factory(client)
2392
2393        assert server.get_state_string() in [
2394            b"before/accept initialization", b"before SSL initialization"
2395        ]
2396        assert client.get_state_string() in [
2397            b"before/connect initialization", b"before SSL initialization"
2398        ]
2399
2400    def test_app_data(self):
2401        """
2402        Any object can be set as app data by passing it to
2403        `Connection.set_app_data` and later retrieved with
2404        `Connection.get_app_data`.
2405        """
2406        conn = Connection(Context(TLSv1_METHOD), None)
2407        assert None is conn.get_app_data()
2408        app_data = object()
2409        conn.set_app_data(app_data)
2410        assert conn.get_app_data() is app_data
2411
2412    def test_makefile(self):
2413        """
2414        `Connection.makefile` is not implemented and calling that
2415        method raises `NotImplementedError`.
2416        """
2417        conn = Connection(Context(TLSv1_METHOD), None)
2418        with pytest.raises(NotImplementedError):
2419            conn.makefile()
2420
2421    def test_get_certificate(self):
2422        """
2423        `Connection.get_certificate` returns the local certificate.
2424        """
2425        chain = _create_certificate_chain()
2426        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
2427
2428        context = Context(TLSv1_METHOD)
2429        context.use_certificate(scert)
2430        client = Connection(context, None)
2431        cert = client.get_certificate()
2432        assert cert is not None
2433        assert "Server Certificate" == cert.get_subject().CN
2434
2435    def test_get_certificate_none(self):
2436        """
2437        `Connection.get_certificate` returns the local certificate.
2438
2439        If there is no certificate, it returns None.
2440        """
2441        context = Context(TLSv1_METHOD)
2442        client = Connection(context, None)
2443        cert = client.get_certificate()
2444        assert cert is None
2445
2446    def test_get_peer_cert_chain(self):
2447        """
2448        `Connection.get_peer_cert_chain` returns a list of certificates
2449        which the connected server returned for the certification verification.
2450        """
2451        chain = _create_certificate_chain()
2452        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
2453
2454        serverContext = Context(TLSv1_METHOD)
2455        serverContext.use_privatekey(skey)
2456        serverContext.use_certificate(scert)
2457        serverContext.add_extra_chain_cert(icert)
2458        serverContext.add_extra_chain_cert(cacert)
2459        server = Connection(serverContext, None)
2460        server.set_accept_state()
2461
2462        # Create the client
2463        clientContext = Context(TLSv1_METHOD)
2464        clientContext.set_verify(VERIFY_NONE, verify_cb)
2465        client = Connection(clientContext, None)
2466        client.set_connect_state()
2467
2468        interact_in_memory(client, server)
2469
2470        chain = client.get_peer_cert_chain()
2471        assert len(chain) == 3
2472        assert "Server Certificate" == chain[0].get_subject().CN
2473        assert "Intermediate Certificate" == chain[1].get_subject().CN
2474        assert "Authority Certificate" == chain[2].get_subject().CN
2475
2476    def test_get_peer_cert_chain_none(self):
2477        """
2478        `Connection.get_peer_cert_chain` returns `None` if the peer sends
2479        no certificate chain.
2480        """
2481        ctx = Context(TLSv1_METHOD)
2482        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
2483        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
2484        server = Connection(ctx, None)
2485        server.set_accept_state()
2486        client = Connection(Context(TLSv1_METHOD), None)
2487        client.set_connect_state()
2488        interact_in_memory(client, server)
2489        assert None is server.get_peer_cert_chain()
2490
2491    def test_get_session_unconnected(self):
2492        """
2493        `Connection.get_session` returns `None` when used with an object
2494        which has not been connected.
2495        """
2496        ctx = Context(TLSv1_METHOD)
2497        server = Connection(ctx, None)
2498        session = server.get_session()
2499        assert None is session
2500
2501    def test_server_get_session(self):
2502        """
2503        On the server side of a connection, `Connection.get_session` returns a
2504        `Session` instance representing the SSL session for that connection.
2505        """
2506        server, client = loopback()
2507        session = server.get_session()
2508        assert isinstance(session, Session)
2509
2510    def test_client_get_session(self):
2511        """
2512        On the client side of a connection, `Connection.get_session`
2513        returns a `Session` instance representing the SSL session for
2514        that connection.
2515        """
2516        server, client = loopback()
2517        session = client.get_session()
2518        assert isinstance(session, Session)
2519
2520    def test_set_session_wrong_args(self):
2521        """
2522        `Connection.set_session` raises `TypeError` if called with an object
2523        that is not an instance of `Session`.
2524        """
2525        ctx = Context(TLSv1_METHOD)
2526        connection = Connection(ctx, None)
2527        with pytest.raises(TypeError):
2528            connection.set_session(123)
2529        with pytest.raises(TypeError):
2530            connection.set_session("hello")
2531        with pytest.raises(TypeError):
2532            connection.set_session(object())
2533
2534    def test_client_set_session(self):
2535        """
2536        `Connection.set_session`, when used prior to a connection being
2537        established, accepts a `Session` instance and causes an attempt to
2538        re-use the session it represents when the SSL handshake is performed.
2539        """
2540        key = load_privatekey(FILETYPE_PEM, server_key_pem)
2541        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
2542        ctx = Context(TLSv1_2_METHOD)
2543        ctx.use_privatekey(key)
2544        ctx.use_certificate(cert)
2545        ctx.set_session_id("unity-test")
2546
2547        def makeServer(socket):
2548            server = Connection(ctx, socket)
2549            server.set_accept_state()
2550            return server
2551
2552        originalServer, originalClient = loopback(
2553            server_factory=makeServer)
2554        originalSession = originalClient.get_session()
2555
2556        def makeClient(socket):
2557            client = loopback_client_factory(socket)
2558            client.set_session(originalSession)
2559            return client
2560        resumedServer, resumedClient = loopback(
2561            server_factory=makeServer,
2562            client_factory=makeClient)
2563
2564        # This is a proxy: in general, we have no access to any unique
2565        # identifier for the session (new enough versions of OpenSSL expose
2566        # a hash which could be usable, but "new enough" is very, very new).
2567        # Instead, exploit the fact that the master key is re-used if the
2568        # session is re-used.  As long as the master key for the two
2569        # connections is the same, the session was re-used!
2570        assert originalServer.master_key() == resumedServer.master_key()
2571
2572    def test_set_session_wrong_method(self):
2573        """
2574        If `Connection.set_session` is passed a `Session` instance associated
2575        with a context using a different SSL method than the `Connection`
2576        is using, a `OpenSSL.SSL.Error` is raised.
2577        """
2578        # Make this work on both OpenSSL 1.0.0, which doesn't support TLSv1.2
2579        # and also on OpenSSL 1.1.0 which doesn't support SSLv3. (SSL_ST_INIT
2580        # is a way to check for 1.1.0)
2581        if SSL_ST_INIT is None:
2582            v1 = TLSv1_2_METHOD
2583            v2 = TLSv1_METHOD
2584        elif hasattr(_lib, "SSLv3_method"):
2585            v1 = TLSv1_METHOD
2586            v2 = SSLv3_METHOD
2587        else:
2588            pytest.skip("Test requires either OpenSSL 1.1.0 or SSLv3")
2589
2590        key = load_privatekey(FILETYPE_PEM, server_key_pem)
2591        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
2592        ctx = Context(v1)
2593        ctx.use_privatekey(key)
2594        ctx.use_certificate(cert)
2595        ctx.set_session_id("unity-test")
2596
2597        def makeServer(socket):
2598            server = Connection(ctx, socket)
2599            server.set_accept_state()
2600            return server
2601
2602        def makeOriginalClient(socket):
2603            client = Connection(Context(v1), socket)
2604            client.set_connect_state()
2605            return client
2606
2607        originalServer, originalClient = loopback(
2608            server_factory=makeServer, client_factory=makeOriginalClient)
2609        originalSession = originalClient.get_session()
2610
2611        def makeClient(socket):
2612            # Intentionally use a different, incompatible method here.
2613            client = Connection(Context(v2), socket)
2614            client.set_connect_state()
2615            client.set_session(originalSession)
2616            return client
2617
2618        with pytest.raises(Error):
2619            loopback(client_factory=makeClient, server_factory=makeServer)
2620
2621    def test_wantWriteError(self):
2622        """
2623        `Connection` methods which generate output raise
2624        `OpenSSL.SSL.WantWriteError` if writing to the connection's BIO
2625        fail indicating a should-write state.
2626        """
2627        client_socket, server_socket = socket_pair()
2628        # Fill up the client's send buffer so Connection won't be able to write
2629        # anything.  Only write a single byte at a time so we can be sure we
2630        # completely fill the buffer.  Even though the socket API is allowed to
2631        # signal a short write via its return value it seems this doesn't
2632        # always happen on all platforms (FreeBSD and OS X particular) for the
2633        # very last bit of available buffer space.
2634        msg = b"x"
2635        for i in range(1024 * 1024 * 64):
2636            try:
2637                client_socket.send(msg)
2638            except error as e:
2639                if e.errno == EWOULDBLOCK:
2640                    break
2641                raise
2642        else:
2643            pytest.fail(
2644                "Failed to fill socket buffer, cannot test BIO want write")
2645
2646        ctx = Context(TLSv1_METHOD)
2647        conn = Connection(ctx, client_socket)
2648        # Client's speak first, so make it an SSL client
2649        conn.set_connect_state()
2650        with pytest.raises(WantWriteError):
2651            conn.do_handshake()
2652
2653    # XXX want_read
2654
2655    def test_get_finished_before_connect(self):
2656        """
2657        `Connection.get_finished` returns `None` before TLS handshake
2658        is completed.
2659        """
2660        ctx = Context(TLSv1_METHOD)
2661        connection = Connection(ctx, None)
2662        assert connection.get_finished() is None
2663
2664    def test_get_peer_finished_before_connect(self):
2665        """
2666        `Connection.get_peer_finished` returns `None` before TLS handshake
2667        is completed.
2668        """
2669        ctx = Context(TLSv1_METHOD)
2670        connection = Connection(ctx, None)
2671        assert connection.get_peer_finished() is None
2672
2673    def test_get_finished(self):
2674        """
2675        `Connection.get_finished` method returns the TLS Finished message send
2676        from client, or server. Finished messages are send during
2677        TLS handshake.
2678        """
2679        server, client = loopback()
2680
2681        assert server.get_finished() is not None
2682        assert len(server.get_finished()) > 0
2683
2684    def test_get_peer_finished(self):
2685        """
2686        `Connection.get_peer_finished` method returns the TLS Finished
2687        message received from client, or server. Finished messages are send
2688        during TLS handshake.
2689        """
2690        server, client = loopback()
2691
2692        assert server.get_peer_finished() is not None
2693        assert len(server.get_peer_finished()) > 0
2694
2695    def test_tls_finished_message_symmetry(self):
2696        """
2697        The TLS Finished message send by server must be the TLS Finished
2698        message received by client.
2699
2700        The TLS Finished message send by client must be the TLS Finished
2701        message received by server.
2702        """
2703        server, client = loopback()
2704
2705        assert server.get_finished() == client.get_peer_finished()
2706        assert client.get_finished() == server.get_peer_finished()
2707
2708    def test_get_cipher_name_before_connect(self):
2709        """
2710        `Connection.get_cipher_name` returns `None` if no connection
2711        has been established.
2712        """
2713        ctx = Context(TLSv1_METHOD)
2714        conn = Connection(ctx, None)
2715        assert conn.get_cipher_name() is None
2716
2717    def test_get_cipher_name(self):
2718        """
2719        `Connection.get_cipher_name` returns a `unicode` string giving the
2720        name of the currently used cipher.
2721        """
2722        server, client = loopback()
2723        server_cipher_name, client_cipher_name = \
2724            server.get_cipher_name(), client.get_cipher_name()
2725
2726        assert isinstance(server_cipher_name, text_type)
2727        assert isinstance(client_cipher_name, text_type)
2728
2729        assert server_cipher_name == client_cipher_name
2730
2731    def test_get_cipher_version_before_connect(self):
2732        """
2733        `Connection.get_cipher_version` returns `None` if no connection
2734        has been established.
2735        """
2736        ctx = Context(TLSv1_METHOD)
2737        conn = Connection(ctx, None)
2738        assert conn.get_cipher_version() is None
2739
2740    def test_get_cipher_version(self):
2741        """
2742        `Connection.get_cipher_version` returns a `unicode` string giving
2743        the protocol name of the currently used cipher.
2744        """
2745        server, client = loopback()
2746        server_cipher_version, client_cipher_version = \
2747            server.get_cipher_version(), client.get_cipher_version()
2748
2749        assert isinstance(server_cipher_version, text_type)
2750        assert isinstance(client_cipher_version, text_type)
2751
2752        assert server_cipher_version == client_cipher_version
2753
2754    def test_get_cipher_bits_before_connect(self):
2755        """
2756        `Connection.get_cipher_bits` returns `None` if no connection has
2757        been established.
2758        """
2759        ctx = Context(TLSv1_METHOD)
2760        conn = Connection(ctx, None)
2761        assert conn.get_cipher_bits() is None
2762
2763    def test_get_cipher_bits(self):
2764        """
2765        `Connection.get_cipher_bits` returns the number of secret bits
2766        of the currently used cipher.
2767        """
2768        server, client = loopback()
2769        server_cipher_bits, client_cipher_bits = \
2770            server.get_cipher_bits(), client.get_cipher_bits()
2771
2772        assert isinstance(server_cipher_bits, int)
2773        assert isinstance(client_cipher_bits, int)
2774
2775        assert server_cipher_bits == client_cipher_bits
2776
2777    def test_get_protocol_version_name(self):
2778        """
2779        `Connection.get_protocol_version_name()` returns a string giving the
2780        protocol version of the current connection.
2781        """
2782        server, client = loopback()
2783        client_protocol_version_name = client.get_protocol_version_name()
2784        server_protocol_version_name = server.get_protocol_version_name()
2785
2786        assert isinstance(server_protocol_version_name, text_type)
2787        assert isinstance(client_protocol_version_name, text_type)
2788
2789        assert server_protocol_version_name == client_protocol_version_name
2790
2791    def test_get_protocol_version(self):
2792        """
2793        `Connection.get_protocol_version()` returns an integer
2794        giving the protocol version of the current connection.
2795        """
2796        server, client = loopback()
2797        client_protocol_version = client.get_protocol_version()
2798        server_protocol_version = server.get_protocol_version()
2799
2800        assert isinstance(server_protocol_version, int)
2801        assert isinstance(client_protocol_version, int)
2802
2803        assert server_protocol_version == client_protocol_version
2804
2805    def test_wantReadError(self):
2806        """
2807        `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are
2808        no bytes available to be read from the BIO.
2809        """
2810        ctx = Context(TLSv1_METHOD)
2811        conn = Connection(ctx, None)
2812        with pytest.raises(WantReadError):
2813            conn.bio_read(1024)
2814
2815    @pytest.mark.parametrize('bufsize', [1.0, None, object(), 'bufsize'])
2816    def test_bio_read_wrong_args(self, bufsize):
2817        """
2818        `Connection.bio_read` raises `TypeError` if passed a non-integer
2819        argument.
2820        """
2821        ctx = Context(TLSv1_METHOD)
2822        conn = Connection(ctx, None)
2823        with pytest.raises(TypeError):
2824            conn.bio_read(bufsize)
2825
2826    def test_buffer_size(self):
2827        """
2828        `Connection.bio_read` accepts an integer giving the maximum number
2829        of bytes to read and return.
2830        """
2831        ctx = Context(TLSv1_METHOD)
2832        conn = Connection(ctx, None)
2833        conn.set_connect_state()
2834        try:
2835            conn.do_handshake()
2836        except WantReadError:
2837            pass
2838        data = conn.bio_read(2)
2839        assert 2 == len(data)
2840
2841    @skip_if_py3
2842    def test_buffer_size_long(self):
2843        """
2844        On Python 2 `Connection.bio_read` accepts values of type `long` as
2845        well as `int`.
2846        """
2847        ctx = Context(TLSv1_METHOD)
2848        conn = Connection(ctx, None)
2849        conn.set_connect_state()
2850        try:
2851            conn.do_handshake()
2852        except WantReadError:
2853            pass
2854        data = conn.bio_read(long(2))
2855        assert 2 == len(data)
2856
2857
class TestConnectionGetCipherList(object):
    """
    Tests for `Connection.get_cipher_list`.
    """
    def test_result(self):
        """
        `Connection.get_cipher_list` returns a `list` in which every
        element is a native `str`.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        cipher_names = connection.get_cipher_list()
        assert isinstance(cipher_names, list)
        assert all(isinstance(name, str) for name in cipher_names)
2872
2873
class VeryLarge(bytes):
    """
    A `bytes` stand-in that merely *reports* a length of 2**31, so tests
    can exercise the "too large" path without allocating that much memory.
    """
    def __len__(self):
        # 1 << 31 == 2**31
        return 1 << 31
2880
2881
class TestConnectionSend(object):
    """
    Tests for `Connection.send`.
    """
    def test_wrong_args(self):
        """
        `Connection.send` raises `TypeError` when handed a plain `object()`
        as its first argument.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        pytest.raises(TypeError, connection.send, object())

    def test_short_bytes(self):
        """
        A short byte string passed to `Connection.send` is transmitted in
        full and the number of bytes sent is returned.
        """
        server, client = loopback()
        sent = server.send(b'xy')
        assert sent == 2
        assert b'xy' == client.recv(2)

    def test_text(self):
        """
        Text passed to `Connection.send` is transmitted in full and the
        number of bytes sent is returned, but a `DeprecationWarning` is
        emitted.
        """
        server, client = loopback()
        text = b"xy".decode("ascii")
        expected_message = (
            "{0} for buf is no longer accepted, use bytes".format(
                WARNING_TYPE_EXPECTED
            ))
        with pytest.warns(DeprecationWarning) as caught:
            simplefilter("always")
            sent = server.send(text)
            assert expected_message == str(caught[-1].message)
        assert sent == 2
        assert b'xy' == client.recv(2)

    def test_short_memoryview(self):
        """
        A `memoryview` over a few bytes passed to `Connection.send` is
        transmitted in full and the number of bytes sent is returned.
        """
        server, client = loopback()
        sent = server.send(memoryview(b'xy'))
        assert sent == 2
        assert b'xy' == client.recv(2)

    @skip_if_py3
    def test_short_buffer(self):
        """
        A `buffer` over a few bytes passed to `Connection.send` is
        transmitted in full and the number of bytes sent is returned.
        """
        server, client = loopback()
        sent = server.send(buffer(b'xy'))
        assert sent == 2
        assert b'xy' == client.recv(2)

    @pytest.mark.skipif(
        sys.maxsize < 2**31,
        reason="sys.maxsize < 2**31 - test requires 64 bit"
    )
    def test_buf_too_large(self):
        """
        `Connection.send` raises `ValueError` for a buffer whose length is
        >= 2**31, since SSL_write only accepts an int for the buffer
        length.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        oversized = VeryLarge()
        with pytest.raises(ValueError) as exc_info:
            connection.send(oversized)
        exc_info.match(r"Cannot send more than .+ bytes at once")
2958
2959
2960def _make_memoryview(size):
2961    """
2962    Create a new ``memoryview`` wrapped around a ``bytearray`` of the given
2963    size.
2964    """
2965    return memoryview(bytearray(size))
2966
2967
class TestConnectionRecvInto(object):
    """
    Tests for `Connection.recv_into`.

    The ``_*_test`` helpers take a ``factory`` callable that builds a
    writable buffer of a given size, so each scenario can be run against
    both `bytearray` and `memoryview` buffers.
    """
    def _no_length_test(self, factory):
        """
        Assert that when the given buffer is passed to `Connection.recv_into`,
        whatever bytes are available to be received that fit into that buffer
        are written into that buffer.
        """
        output_buffer = factory(5)

        server, client = loopback()
        server.send(b'xy')

        assert client.recv_into(output_buffer) == 2
        assert output_buffer == bytearray(b'xy\x00\x00\x00')

    def test_bytearray_no_length(self):
        """
        `Connection.recv_into` can be passed a `bytearray` instance and data
        in the receive buffer is written to it.
        """
        self._no_length_test(bytearray)

    def _respects_length_test(self, factory):
        """
        Assert that when the given buffer is passed to `Connection.recv_into`
        along with a value for `nbytes` that is less than the size of that
        buffer, only `nbytes` bytes are written into the buffer.
        """
        output_buffer = factory(10)

        server, client = loopback()
        server.send(b'abcdefghij')

        assert client.recv_into(output_buffer, 5) == 5
        assert output_buffer == bytearray(b'abcde\x00\x00\x00\x00\x00')

    def test_bytearray_respects_length(self):
        """
        When called with a `bytearray` instance, `Connection.recv_into`
        respects the `nbytes` parameter and doesn't copy in more than that
        number of bytes.
        """
        self._respects_length_test(bytearray)

    def _doesnt_overfill_test(self, factory):
        """
        Assert that if there are more bytes available to be read from the
        receive buffer than would fit into the buffer passed to
        `Connection.recv_into`, only as many as fit are written into it.
        """
        output_buffer = factory(5)

        server, client = loopback()
        server.send(b'abcdefghij')

        assert client.recv_into(output_buffer) == 5
        assert output_buffer == bytearray(b'abcde')
        # The bytes that did not fit must still be available afterwards.
        rest = client.recv(5)
        assert b'fghij' == rest

    def test_bytearray_doesnt_overfill(self):
        """
        When called with a `bytearray` instance, `Connection.recv_into`
        respects the size of the array and doesn't write more bytes into it
        than will fit.
        """
        self._doesnt_overfill_test(bytearray)

    def _really_doesnt_overfill_test(self, factory):
        """
        Assert that if the value given for `nbytes` is greater than the
        actual size of the buffer passed to `Connection.recv_into`, only as
        many bytes as fit into the buffer are written into it.
        """
        output_buffer = factory(5)

        server, client = loopback()
        server.send(b'abcdefghij')

        # Ask for far more than the 5-byte buffer can hold.
        assert client.recv_into(output_buffer, 50) == 5
        assert output_buffer == bytearray(b'abcde')
        # The bytes that did not fit must still be available afterwards.
        rest = client.recv(5)
        assert b'fghij' == rest

    def test_bytearray_really_doesnt_overfill(self):
        """
        When called with a `bytearray` instance and an `nbytes` value that is
        too large, `Connection.recv_into` respects the size of the array and
        not the `nbytes` value and doesn't write more bytes into the buffer
        than will fit.
        """
        self._really_doesnt_overfill_test(bytearray)

    def test_peek(self):
        """
        With the `MSG_PEEK` flag, `Connection.recv_into` copies the pending
        data into the buffer without consuming it, so a second peek sees
        the same bytes again.
        """
        server, client = loopback()
        server.send(b'xy')

        for _ in range(2):
            output_buffer = bytearray(5)
            assert client.recv_into(output_buffer, flags=MSG_PEEK) == 2
            assert output_buffer == bytearray(b'xy\x00\x00\x00')

    def test_memoryview_no_length(self):
        """
        `Connection.recv_into` can be passed a `memoryview` instance and data
        in the receive buffer is written to it.
        """
        self._no_length_test(_make_memoryview)

    def test_memoryview_respects_length(self):
        """
        When called with a `memoryview` instance, `Connection.recv_into`
        respects the ``nbytes`` parameter and doesn't copy more than that
        number of bytes in.
        """
        self._respects_length_test(_make_memoryview)

    def test_memoryview_doesnt_overfill(self):
        """
        When called with a `memoryview` instance, `Connection.recv_into`
        respects the size of the array and doesn't write more bytes into it
        than will fit.
        """
        self._doesnt_overfill_test(_make_memoryview)

    def test_memoryview_really_doesnt_overfill(self):
        """
        When called with a `memoryview` instance and an `nbytes` value that is
        too large, `Connection.recv_into` respects the size of the array and
        not the `nbytes` value and doesn't write more bytes into the buffer
        than will fit.
        """
        self._really_doesnt_overfill_test(_make_memoryview)
3088
3089
class TestConnectionSendall(object):
    """
    Tests for `Connection.sendall`.
    """
    def test_wrong_args(self):
        """
        `Connection.sendall` raises `TypeError` when handed a plain
        `object()` as its first argument.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        with pytest.raises(TypeError):
            connection.sendall(object())

    def test_short(self):
        """
        `Connection.sendall` transmits all of the bytes in the string
        passed to it.
        """
        server, client = loopback()
        server.sendall(b'x')
        assert client.recv(1) == b'x'

    def test_text(self):
        """
        `Connection.sendall` transmits all the content of a text string
        passed to it, but emits a `DeprecationWarning` for the text type.
        """
        server, client = loopback()
        expected_message = (
            "{0} for buf is no longer accepted, use bytes".format(
                WARNING_TYPE_EXPECTED
            ))
        with pytest.warns(DeprecationWarning) as w:
            simplefilter("always")
            server.sendall(b"x".decode("ascii"))
            assert expected_message == str(w[-1].message)
        assert client.recv(1) == b"x"

    def test_short_memoryview(self):
        """
        When passed a memoryview onto a small number of bytes,
        `Connection.sendall` transmits all of them.
        """
        server, client = loopback()
        server.sendall(memoryview(b'x'))
        assert client.recv(1) == b'x'

    @skip_if_py3
    def test_short_buffers(self):
        """
        When passed a buffer containing a small number of bytes,
        `Connection.sendall` transmits all of them.
        """
        server, client = loopback()
        server.sendall(buffer(b'x'))
        assert client.recv(1) == b'x'

    def test_long(self):
        """
        `Connection.sendall` transmits all the bytes in the string passed to it
        even if this requires multiple calls of an underlying write function.
        """
        server, client = loopback()
        # Should be enough, underlying SSL_write should only do 16k at a time.
        # On Windows, after 32k of bytes the write will block (forever
        # - because no one is yet reading).
        message = b'x' * (1024 * 32 - 1) + b'y'
        server.sendall(message)
        # Drain the client in 1k reads until the whole message has arrived.
        chunks = []
        remaining = len(message)
        while remaining > 0:
            chunk = client.recv(1024)
            chunks.append(chunk)
            remaining -= len(chunk)
        assert message == b''.join(chunks)

    def test_closed(self):
        """
        If the underlying socket has been shut down in both directions,
        `Connection.sendall` propagates the write error from the low level
        write call as a `SysCallError` whose first argument is the errno
        (`ESHUTDOWN` on win32, `EPIPE` elsewhere).
        """
        server, client = loopback()
        server.sock_shutdown(SHUT_RDWR)
        with pytest.raises(SysCallError) as err:
            server.sendall(b"hello, world")
        expected_errno = ESHUTDOWN if platform == "win32" else EPIPE
        assert err.value.args[0] == expected_errno
3178
3179
class TestConnectionRenegotiate(object):
    """
    Tests for SSL renegotiation APIs.
    """
    def test_total_renegotiations(self):
        """
        `Connection.total_renegotiations` returns `0` on a connection for
        which no renegotiation has taken place.
        """
        fresh = Connection(Context(TLSv1_METHOD), None)
        assert 0 == fresh.total_renegotiations()

    def test_renegotiate(self):
        """
        Walk through a full renegotiation cycle on a TLSv1.2 loopback
        connection, checking the renegotiation counters and flags on the
        server along the way.
        """
        def tls12_server(sock):
            return loopback_server_factory(sock, TLSv1_2_METHOD)

        def tls12_client(sock):
            return loopback_client_factory(sock, TLSv1_2_METHOD)

        server, client = loopback(tls12_server, tls12_client)

        greeting = b"hello world"
        server.send(b"hello world")
        assert b"hello world" == client.recv(len(greeting))

        # Nothing has been renegotiated or requested yet.
        assert server.total_renegotiations() == 0
        assert server.renegotiate_pending() is False

        # Request a renegotiation from the server side.
        assert server.renegotiate() is True
        assert server.renegotiate_pending() is True

        server.setblocking(False)
        client.setblocking(False)

        client.do_handshake()
        server.do_handshake()

        assert server.total_renegotiations() == 1
        # NOTE(review): this spins, with no timeout, for as long as
        # renegotiate_pending() keeps returning False.
        while server.renegotiate_pending() is False:
            pass
3221
3222
class TestError(object):
    """
    Unit tests for `OpenSSL.SSL.Error`.
    """
    def test_type(self):
        """
        `Error` is a subclass of `Exception` and its name is ``Error``.
        """
        assert 'Error' == Error.__name__
        assert issubclass(Error, Exception)
3233
3234
class TestConstants(object):
    """
    Tests for the values of the constants exposed in `OpenSSL.SSL`.

    These values are defined by OpenSSL and are only meant to be handed back
    to OpenSSL APIs as flags, so the one thing that can usefully be asserted
    about them is their numeric value.
    """
    @pytest.mark.skipif(
        OP_NO_QUERY_MTU is None,
        reason=(
            "OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old"
        )
    )
    def test_op_no_query_mtu(self):
        """
        `OpenSSL.SSL.OP_NO_QUERY_MTU` is 0x1000, matching
        `SSL_OP_NO_QUERY_MTU` as defined by `openssl/ssl.h`.
        """
        assert 0x1000 == OP_NO_QUERY_MTU

    @pytest.mark.skipif(
        OP_COOKIE_EXCHANGE is None,
        reason=(
            "OP_COOKIE_EXCHANGE unavailable - OpenSSL version "
            "may be too old"
        )
    )
    def test_op_cookie_exchange(self):
        """
        `OpenSSL.SSL.OP_COOKIE_EXCHANGE` is 0x2000, matching
        `SSL_OP_COOKIE_EXCHANGE` as defined by `openssl/ssl.h`.
        """
        assert 0x2000 == OP_COOKIE_EXCHANGE

    @pytest.mark.skipif(
        OP_NO_TICKET is None,
        reason=(
            "OP_NO_TICKET unavailable - OpenSSL version may be too old"
        )
    )
    def test_op_no_ticket(self):
        """
        `OpenSSL.SSL.OP_NO_TICKET` is 0x4000, matching `SSL_OP_NO_TICKET`
        as defined by `openssl/ssl.h`.
        """
        assert 0x4000 == OP_NO_TICKET

    @pytest.mark.skipif(
        OP_NO_COMPRESSION is None,
        reason="OP_NO_COMPRESSION unavailable - OpenSSL version may be too old"
    )
    def test_op_no_compression(self):
        """
        `OpenSSL.SSL.OP_NO_COMPRESSION` is 0x20000, matching
        `SSL_OP_NO_COMPRESSION` as defined by `openssl/ssl.h`.
        """
        assert 0x20000 == OP_NO_COMPRESSION

    def test_sess_cache_off(self):
        """
        `OpenSSL.SSL.SESS_CACHE_OFF` is 0x0, matching `SSL_SESS_CACHE_OFF`
        as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_OFF == 0x0

    def test_sess_cache_client(self):
        """
        `OpenSSL.SSL.SESS_CACHE_CLIENT` is 0x1, matching
        `SSL_SESS_CACHE_CLIENT` as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_CLIENT == 0x1

    def test_sess_cache_server(self):
        """
        `OpenSSL.SSL.SESS_CACHE_SERVER` is 0x2, matching
        `SSL_SESS_CACHE_SERVER` as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_SERVER == 0x2

    def test_sess_cache_both(self):
        """
        `OpenSSL.SSL.SESS_CACHE_BOTH` is 0x3, matching `SSL_SESS_CACHE_BOTH`
        as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_BOTH == 0x3

    def test_sess_cache_no_auto_clear(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_AUTO_CLEAR` is 0x80, matching
        `SSL_SESS_CACHE_NO_AUTO_CLEAR` as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_AUTO_CLEAR == 0x80

    def test_sess_cache_no_internal_lookup(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_LOOKUP` is 0x100, matching
        `SSL_SESS_CACHE_NO_INTERNAL_LOOKUP` as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_INTERNAL_LOOKUP == 0x100

    def test_sess_cache_no_internal_store(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_STORE` is 0x200, matching
        `SSL_SESS_CACHE_NO_INTERNAL_STORE` as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_INTERNAL_STORE == 0x200

    def test_sess_cache_no_internal(self):
        """
        `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL` is 0x300, matching
        `SSL_SESS_CACHE_NO_INTERNAL` as defined by `openssl/ssl.h`.
        """
        assert SESS_CACHE_NO_INTERNAL == 0x300
3347
3348
class TestMemoryBIO(object):
    """
    Tests for `OpenSSL.SSL.Connection` using a memory BIO.
    """
    def _context(self, key_pem, cert_pem):
        """
        Create the `Context` configuration shared by `_server` and `_client`.

        :param key_pem: The PEM-encoded private key the context should use.
        :param cert_pem: The PEM-encoded certificate to present; it is
            checked against `key_pem` with `Context.check_privatekey`.
        :return: A TLSv1 `Context` which verifies its peer using `verify_cb`,
            uses the given key and certificate, and has `root_cert_pem`
            added to its certificate store.
        """
        ctx = Context(TLSv1_METHOD)
        ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE)
        ctx.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE,
            verify_cb
        )
        store = ctx.get_cert_store()
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, key_pem))
        ctx.use_certificate(load_certificate(FILETYPE_PEM, cert_pem))
        ctx.check_privatekey()
        store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
        return ctx

    def _server(self, sock):
        """
        Create a new server-side SSL `Connection` object wrapped around `sock`.

        :param sock: The socket to wrap, or `None` to have the `Connection`
            use a memory BIO instead.
        :return: A `Connection` in the accept state.
        """
        # Mostly setup boilerplate - use TLSv1, use a particular certificate,
        # etc.
        server_ctx = self._context(server_key_pem, server_cert_pem)
        # Here the Connection is actually created.  If None is passed as the
        # 2nd parameter, it indicates a memory BIO should be created.
        server_conn = Connection(server_ctx, sock)
        server_conn.set_accept_state()
        return server_conn

    def _client(self, sock):
        """
        Create a new client-side SSL `Connection` object wrapped around `sock`.

        :param sock: The socket to wrap, or `None` to have the `Connection`
            use a memory BIO instead.
        :return: A `Connection` in the connect state.
        """
        # Same boilerplate as the server side, but with the client's key and
        # certificate.
        client_ctx = self._context(client_key_pem, client_cert_pem)
        client_conn = Connection(client_ctx, sock)
        client_conn.set_connect_state()
        return client_conn

    def test_memory_connect(self):
        """
        Two `Connection`s which use memory BIOs can be manually connected by
        reading from the output of each and writing those bytes to the input of
        the other and in this way establish a connection and exchange
        application-level bytes with each other.
        """
        server_conn = self._server(None)
        client_conn = self._client(None)

        # There should be no key or nonces yet.
        assert server_conn.master_key() is None
        assert server_conn.client_random() is None
        assert server_conn.server_random() is None

        # First, the handshake needs to happen.  We'll deliver bytes back and
        # forth between the client and server until neither of them feels like
        # speaking any more.
        assert interact_in_memory(client_conn, server_conn) is None

        # Now that the handshake is done, there should be a key and nonces.
        assert server_conn.master_key() is not None
        assert server_conn.client_random() is not None
        assert server_conn.server_random() is not None
        assert server_conn.client_random() == client_conn.client_random()
        assert server_conn.server_random() == client_conn.server_random()
        assert server_conn.client_random() != server_conn.server_random()
        assert client_conn.client_random() != client_conn.server_random()

        # Export key material for other uses.
        cekm = client_conn.export_keying_material(b'LABEL', 32)
        sekm = server_conn.export_keying_material(b'LABEL', 32)
        assert cekm is not None
        assert sekm is not None
        assert cekm == sekm
        assert len(sekm) == 32

        # Export key material for other uses with additional context.
        cekmc = client_conn.export_keying_material(b'LABEL', 32, b'CONTEXT')
        sekmc = server_conn.export_keying_material(b'LABEL', 32, b'CONTEXT')
        assert cekmc is not None
        assert sekmc is not None
        assert cekmc == sekmc
        assert cekmc != cekm
        assert sekmc != sekm
        # Export with alternate label
        cekmt = client_conn.export_keying_material(b'test', 32, b'CONTEXT')
        sekmt = server_conn.export_keying_material(b'test', 32, b'CONTEXT')
        assert cekmc != cekmt
        assert sekmc != sekmt

        # Here are the bytes we'll try to send.
        important_message = b'One if by land, two if by sea.'

        server_conn.write(important_message)
        assert (
            interact_in_memory(client_conn, server_conn) ==
            (client_conn, important_message))

        client_conn.write(important_message[::-1])
        assert (
            interact_in_memory(client_conn, server_conn) ==
            (server_conn, important_message[::-1]))

    def test_socket_connect(self):
        """
        Just like `test_memory_connect` but with an actual socket.

        This is primarily to rule out the memory BIO code as the source of any
        problems encountered while passing data over a `Connection` (if
        this test fails, there must be a problem outside the memory BIO code,
        as no memory BIO is involved here).  Even though this isn't a memory
        BIO test, it's convenient to have it here.
        """
        server_conn, client_conn = loopback()

        important_message = b"Help me Obi Wan Kenobi, you're my only hope."
        client_conn.send(important_message)
        msg = server_conn.recv(1024)
        assert msg == important_message

        # Again in the other direction, just for fun.
        important_message = important_message[::-1]
        server_conn.send(important_message)
        msg = client_conn.recv(1024)
        assert msg == important_message

    def test_socket_overrides_memory(self):
        """
        `Connection.bio_read`, `Connection.bio_write` and
        `Connection.bio_shutdown` raise `TypeError` on a `Connection` that
        was created around a socket rather than a memory BIO.
        """
        context = Context(TLSv1_METHOD)
        client = socket()
        try:
            clientSSL = Connection(context, client)
            with pytest.raises(TypeError):
                clientSSL.bio_read(100)
            # Use a bytes payload so the only reason left for a TypeError is
            # that the connection wraps a socket.
            with pytest.raises(TypeError):
                clientSSL.bio_write(b"foo")
            with pytest.raises(TypeError):
                clientSSL.bio_shutdown()
        finally:
            # The socket is never connected; close it so the test does not
            # leak a file descriptor.
            client.close()

    def test_outgoing_overflow(self):
        """
        If more bytes than can be written to the memory BIO are passed to
        `Connection.send` at once, the number of bytes which were written is
        returned and that many bytes from the beginning of the input can be
        read from the other end of the connection.
        """
        server = self._server(None)
        client = self._client(None)

        interact_in_memory(client, server)

        size = 2 ** 15
        sent = client.send(b"x" * size)
        # Sanity check.  We're trying to test what happens when the entire
        # input can't be sent.  If the entire input was sent, this test is
        # meaningless.
        assert sent < size

        receiver, received = interact_in_memory(client, server)
        assert receiver is server

        # All of these bytes are expected to arrive at once: the in-memory
        # helper is understood to read with a buffer of 2 ** 16 - more than
        # 2 ** 15 (TODO confirm against the helper's implementation).
        assert len(received) == sent

    def test_shutdown(self):
        """
        `Connection.bio_shutdown` signals the end of the data stream
        from which the `Connection` reads.
        """
        server = self._server(None)
        server.bio_shutdown()
        with pytest.raises(Error) as err:
            server.recv(1024)
        # We don't want WantReadError or ZeroReturnError or anything - it's a
        # handshake failure.  The exact type is checked on purpose so that
        # other subclasses of Error are rejected.
        assert type(err.value) in [Error, SysCallError]

    def test_unexpected_EOF(self):
        """
        If the connection is lost before an orderly SSL shutdown occurs,
        `OpenSSL.SSL.SysCallError` is raised with a message of
        "Unexpected EOF".
        """
        server_conn, client_conn = loopback()
        client_conn.sock_shutdown(SHUT_RDWR)
        with pytest.raises(SysCallError) as err:
            server_conn.recv(1024)
        assert err.value.args == (-1, "Unexpected EOF")

    def _check_client_ca_list(self, func):
        """
        Verify the return value of the `get_client_ca_list` method for
        server and client connections.

        :param func: A function which will be called with the server context
            before the client and server are connected to each other.  This
            function should specify a list of CAs for the server to send to the
            client and return that same list.  The list will be used to verify
            that `get_client_ca_list` returns the proper value at
            various times.
        """
        server = self._server(None)
        client = self._client(None)
        # Nothing has been configured yet on either side.
        assert client.get_client_ca_list() == []
        assert server.get_client_ca_list() == []
        ctx = server.get_context()
        expected = func(ctx)
        # The server sees the configured list immediately; the client only
        # sees it once the two connections have talked to each other.
        assert client.get_client_ca_list() == []
        assert server.get_client_ca_list() == expected
        interact_in_memory(client, server)
        assert client.get_client_ca_list() == expected
        assert server.get_client_ca_list() == expected

    def test_set_client_ca_list_errors(self):
        """
        `Context.set_client_ca_list` raises a `TypeError` if called with a
        non-list or a list that contains objects other than X509Names.
        """
        ctx = Context(TLSv1_METHOD)
        with pytest.raises(TypeError):
            ctx.set_client_ca_list("spam")
        with pytest.raises(TypeError):
            ctx.set_client_ca_list(["spam"])

    def test_set_empty_ca_list(self):
        """
        If passed an empty list, `Context.set_client_ca_list` configures the
        context to send no CA names to the client and, on both the server and
        client sides, `Connection.get_client_ca_list` returns an empty list
        after the connection is set up.
        """
        def no_ca(ctx):
            ctx.set_client_ca_list([])
            return []
        self._check_client_ca_list(no_ca)

    def test_set_one_ca_list(self):
        """
        If passed a list containing a single X509Name,
        `Context.set_client_ca_list` configures the context to send
        that CA name to the client and, on both the server and client sides,
        `Connection.get_client_ca_list` returns a list containing that
        X509Name after the connection is set up.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        cadesc = cacert.get_subject()

        def single_ca(ctx):
            ctx.set_client_ca_list([cadesc])
            return [cadesc]
        self._check_client_ca_list(single_ca)

    def test_set_multiple_ca_list(self):
        """
        If passed a list containing multiple X509Name objects,
        `Context.set_client_ca_list` configures the context to send
        those CA names to the client and, on both the server and client sides,
        `Connection.get_client_ca_list` returns a list containing those
        X509Names after the connection is set up.
        """
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        # Use the client certificate so the two names are loaded from
        # different certificates.
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        sedesc = secert.get_subject()
        cldesc = clcert.get_subject()

        def multiple_ca(ctx):
            L = [sedesc, cldesc]
            ctx.set_client_ca_list(L)
            return L
        self._check_client_ca_list(multiple_ca)

    def test_reset_ca_list(self):
        """
        If called multiple times, only the X509Names passed to the final call
        of `Context.set_client_ca_list` are used to configure the CA
        names sent to the client.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()
        cldesc = clcert.get_subject()

        def changed_ca(ctx):
            ctx.set_client_ca_list([sedesc, cldesc])
            ctx.set_client_ca_list([cadesc])
            return [cadesc]
        self._check_client_ca_list(changed_ca)

    def test_mutated_ca_list(self):
        """
        If the list passed to `Context.set_client_ca_list` is mutated
        afterwards, this does not affect the list of CA names sent to the
        client.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()

        def mutated_ca(ctx):
            L = [cadesc]
            # Hand over the very list object that is mutated below;
            # otherwise the mutation could never be observed by the context.
            ctx.set_client_ca_list(L)
            L.append(sedesc)
            return [cadesc]
        self._check_client_ca_list(mutated_ca)

    def test_add_client_ca_wrong_args(self):
        """
        `Context.add_client_ca` raises `TypeError` if called with
        a non-X509 object.
        """
        ctx = Context(TLSv1_METHOD)
        with pytest.raises(TypeError):
            ctx.add_client_ca("spam")

    def test_one_add_client_ca(self):
        """
        A certificate's subject can be added as a CA to be sent to the client
        with `Context.add_client_ca`.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        cadesc = cacert.get_subject()

        def single_ca(ctx):
            ctx.add_client_ca(cacert)
            return [cadesc]
        self._check_client_ca_list(single_ca)

    def test_multiple_add_client_ca(self):
        """
        Multiple CA names can be sent to the client by calling
        `Context.add_client_ca` with multiple X509 objects.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()

        def multiple_ca(ctx):
            ctx.add_client_ca(cacert)
            ctx.add_client_ca(secert)
            return [cadesc, sedesc]
        self._check_client_ca_list(multiple_ca)

    def test_set_and_add_client_ca(self):
        """
        A call to `Context.set_client_ca_list` followed by a call to
        `Context.add_client_ca` results in using the CA names from the
        first call and the CA name from the second call.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()
        cldesc = clcert.get_subject()

        def mixed_set_add_ca(ctx):
            ctx.set_client_ca_list([cadesc, sedesc])
            ctx.add_client_ca(clcert)
            return [cadesc, sedesc, cldesc]
        self._check_client_ca_list(mixed_set_add_ca)

    def test_set_after_add_client_ca(self):
        """
        A call to `Context.set_client_ca_list` after a call to
        `Context.add_client_ca` replaces the CA name specified by the
        former call with the names specified by the latter call.
        """
        cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
        secert = load_certificate(FILETYPE_PEM, server_cert_pem)
        # Loaded from the client certificate so that the name added first is
        # not the same one that is added again after the replacement.
        clcert = load_certificate(FILETYPE_PEM, client_cert_pem)

        cadesc = cacert.get_subject()
        sedesc = secert.get_subject()

        def set_replaces_add_ca(ctx):
            ctx.add_client_ca(clcert)
            ctx.set_client_ca_list([cadesc])
            ctx.add_client_ca(secert)
            return [cadesc, sedesc]
        self._check_client_ca_list(set_replaces_add_ca)
3754
3755
class TestInfoConstants(object):
    """
    Tests for the assorted constants exposed for use in info callbacks.
    """
    def test_integers(self):
        """
        Every info constant is an integer.

        This is a very weak test.  A better one would check that, as
        particular info events occur, the value handed to the info callback
        agrees with the corresponding constant exposed by OpenSSL.SSL.
        """
        always_present = (
            SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK,
            SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE,
            SSL_CB_ALERT, SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT,
            SSL_CB_ACCEPT_LOOP, SSL_CB_ACCEPT_EXIT,
            SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT,
            SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE,
        )
        for value in always_present:
            assert isinstance(value, int)

        # These constants don't exist on OpenSSL 1.1.0, in which case None
        # is accepted as well.
        maybe_missing = (
            SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE,
        )
        for value in maybe_missing:
            assert isinstance(value, int) or value is None
3782
3783
class TestRequires(object):
    """
    Tests for the decorator factory which makes a function raise
    NotImplementedError when the OpenSSL in use is too old to support it.
    """
    def test_available(self):
        """
        A function decorated for an available feature runs normally and its
        return value is passed through.
        """
        calls = []

        @_make_requires(True, "Error text")
        def inner():
            calls.append(True)
            return True

        assert inner() is True
        assert calls == [True]

    def test_unavailable(self):
        """
        A function decorated for an unavailable feature is never executed;
        calling it raises NotImplementedError carrying the given text.
        """
        @_make_requires(False, "Error text")
        def inner():  # pragma: nocover
            pytest.fail("Should not be called")

        with pytest.raises(NotImplementedError) as excinfo:
            inner()

        assert "Error text" in str(excinfo.value)
3820
3821
class TestOCSP(object):
    """
    Tests for PyOpenSSL's OCSP stapling support.
    """
    sample_ocsp_data = b"this is totally ocsp data"

    def _client_connection(self, callback, data, request_ocsp=True):
        """
        Build a client-side connection set up for OCSP.

        :param callback: The callback to register as the OCSP client
            callback.
        :param data: The opaque object registered together with the
            callback.
        :param request_ocsp: When true (the default), the connection asks
            for OCSP stapling via `Connection.request_ocsp`.
        :return: A `Connection` in the connect state.
        """
        context = Context(SSLv23_METHOD)
        context.set_ocsp_client_callback(callback, data)
        conn = Connection(context)

        if request_ocsp:
            conn.request_ocsp()

        conn.set_connect_state()
        return conn

    def _server_connection(self, callback, data):
        """
        Build a server-side connection set up for OCSP.

        :param callback: The callback to register as the OCSP server
            callback.
        :param data: The opaque object registered together with the
            callback.
        :return: A `Connection` in the accept state.
        """
        key = load_privatekey(FILETYPE_PEM, server_key_pem)
        cert = load_certificate(FILETYPE_PEM, server_cert_pem)

        context = Context(SSLv23_METHOD)
        context.use_privatekey(key)
        context.use_certificate(cert)
        context.set_ocsp_server_callback(callback, data)

        conn = Connection(context)
        conn.set_accept_state()
        return conn

    def test_callbacks_arent_called_by_default(self):
        """
        When both sides register OCSP callbacks but the client never sends
        an OCSP request, neither callback is invoked.
        """
        def ocsp_callback(*args, **kwargs):  # pragma: nocover
            pytest.fail("Should not be called")

        server = self._server_connection(callback=ocsp_callback, data=None)
        client = self._client_connection(
            callback=ocsp_callback, data=None, request_ocsp=False
        )
        handshake_in_memory(client, server)

    def test_client_negotiates_without_server(self):
        """
        When only the client wants OCSP, the handshake still succeeds and
        the client callback is invoked with an empty byte string.
        """
        seen = []

        def ocsp_callback(conn, ocsp_data, ignored):
            seen.append(ocsp_data)
            return True

        client = self._client_connection(callback=ocsp_callback, data=None)
        server = loopback_server_factory(socket=None)
        handshake_in_memory(client, server)

        assert 1 == len(seen)
        assert b'' == seen[0]

    def test_client_receives_servers_data(self):
        """
        Whatever the server callback returns is what the client callback
        receives.
        """
        received = []

        def server_callback(*args, **kwargs):
            return self.sample_ocsp_data

        def client_callback(conn, ocsp_data, ignored):
            received.append(ocsp_data)
            return True

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert 1 == len(received)
        assert self.sample_ocsp_data == received[0]

    def test_callbacks_are_invoked_with_connections(self):
        """
        Each callback receives its own connection as the first argument.
        """
        seen_by_client = []
        seen_by_server = []

        def client_callback(conn, *args, **kwargs):
            seen_by_client.append(conn)
            return True

        def server_callback(conn, *args, **kwargs):
            seen_by_server.append(conn)
            return self.sample_ocsp_data

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert 1 == len(seen_by_client)
        assert 1 == len(seen_by_server)
        assert seen_by_client[0] is client
        assert seen_by_server[0] is server

    def test_opaque_data_is_passed_through(self):
        """
        The opaque, user-supplied object registered with each callback is
        handed back to that callback as its last argument.
        """
        recorded = []

        def server_callback(*args):
            recorded.append(args)
            return self.sample_ocsp_data

        def client_callback(*args):
            recorded.append(args)
            return True

        sentinel = object()

        client = self._client_connection(
            callback=client_callback, data=sentinel
        )
        server = self._server_connection(
            callback=server_callback, data=sentinel
        )
        handshake_in_memory(client, server)

        assert 2 == len(recorded)
        first_args, second_args = recorded[0], recorded[1]
        assert first_args[-1] is sentinel
        assert second_args[-1] is sentinel

    def test_server_returns_empty_string(self):
        """
        An empty bytestring returned by the server callback reaches the
        client callback as an empty bytestring.
        """
        received = []

        def server_callback(*args):
            return b''

        def client_callback(conn, ocsp_data, ignored):
            received.append(ocsp_data)
            return True

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert 1 == len(received)
        assert b'' == received[0]

    def test_client_returns_false_terminates_handshake(self):
        """
        The handshake fails when the client callback returns False.
        """
        def server_callback(*args):
            return self.sample_ocsp_data

        def client_callback(*args):
            return False

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(Error):
            handshake_in_memory(client, server)

    def test_exceptions_in_client_bubble_up(self):
        """
        An exception raised inside the client callback propagates to the
        code driving the handshake.
        """
        class SentinelException(Exception):
            pass

        def server_callback(*args):
            return self.sample_ocsp_data

        def client_callback(*args):
            raise SentinelException()

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(SentinelException):
            handshake_in_memory(client, server)

    def test_exceptions_in_server_bubble_up(self):
        """
        An exception raised inside the server callback propagates to the
        code driving the handshake.
        """
        class SentinelException(Exception):
            pass

        def server_callback(*args):
            raise SentinelException()

        def client_callback(*args):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(SentinelException):
            handshake_in_memory(client, server)

    def test_server_must_return_bytes(self):
        """
        A server callback that returns text instead of a bytestring causes
        a TypeError.
        """
        def server_callback(*args):
            return self.sample_ocsp_data.decode('ascii')

        def client_callback(*args):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(TypeError):
            handshake_in_memory(client, server)
4060