1# Copyright (C) Jean-Paul Calderone 2# See LICENSE for details. 3 4""" 5Unit tests for :mod:`OpenSSL.SSL`. 6""" 7 8import datetime 9import sys 10import uuid 11 12from gc import collect, get_referrers 13from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN 14from sys import platform, getfilesystemencoding 15from socket import MSG_PEEK, SHUT_RDWR, error, socket 16from os import makedirs 17from os.path import join 18from weakref import ref 19from warnings import simplefilter 20 21import pytest 22 23from pretend import raiser 24 25from six import PY3, text_type 26 27from cryptography import x509 28from cryptography.hazmat.backends import default_backend 29from cryptography.hazmat.primitives import hashes 30from cryptography.hazmat.primitives import serialization 31from cryptography.hazmat.primitives.asymmetric import rsa 32from cryptography.x509.oid import NameOID 33 34 35from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM 36from OpenSSL.crypto import PKey, X509, X509Extension, X509Store 37from OpenSSL.crypto import dump_privatekey, load_privatekey 38from OpenSSL.crypto import dump_certificate, load_certificate 39from OpenSSL.crypto import get_elliptic_curves 40 41from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS 42from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON 43from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN 44from OpenSSL.SSL import ( 45 SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, 46 TLSv1_1_METHOD, TLSv1_2_METHOD) 47from OpenSSL.SSL import OP_SINGLE_DH_USE, OP_NO_SSLv2, OP_NO_SSLv3 48from OpenSSL.SSL import ( 49 VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE, VERIFY_NONE) 50 51from OpenSSL import SSL 52from OpenSSL.SSL import ( 53 SESS_CACHE_OFF, SESS_CACHE_CLIENT, SESS_CACHE_SERVER, SESS_CACHE_BOTH, 54 SESS_CACHE_NO_AUTO_CLEAR, SESS_CACHE_NO_INTERNAL_LOOKUP, 55 SESS_CACHE_NO_INTERNAL_STORE, SESS_CACHE_NO_INTERNAL) 56 57from OpenSSL.SSL import ( 58 Error, 
SysCallError, WantReadError, WantWriteError, ZeroReturnError) 59from OpenSSL.SSL import ( 60 Context, ContextType, Session, Connection, ConnectionType, SSLeay_version) 61from OpenSSL.SSL import _make_requires 62 63from OpenSSL._util import ffi as _ffi, lib as _lib 64 65from OpenSSL.SSL import ( 66 OP_NO_QUERY_MTU, OP_COOKIE_EXCHANGE, OP_NO_TICKET, OP_NO_COMPRESSION, 67 MODE_RELEASE_BUFFERS) 68 69from OpenSSL.SSL import ( 70 SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, 71 SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, 72 SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, 73 SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, 74 SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE) 75 76try: 77 from OpenSSL.SSL import ( 78 SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE 79 ) 80except ImportError: 81 SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None 82 83from .util import WARNING_TYPE_EXPECTED, NON_ASCII, is_consistent_type 84from .test_crypto import ( 85 cleartextCertificatePEM, cleartextPrivateKeyPEM, 86 client_cert_pem, client_key_pem, server_cert_pem, server_key_pem, 87 root_cert_pem) 88 89 90# openssl dhparam 1024 -out dh-1024.pem (note that 1024 is a small number of 91# bits to use) 92dhparam = """\ 93-----BEGIN DH PARAMETERS----- 94MIGHAoGBALdUMvn+C9MM+y5BWZs11mSeH6HHoEq0UVbzVq7UojC1hbsZUuGukQ3a 95Qh2/pwqb18BZFykrWB0zv/OkLa0kx4cuUgNrUVq1EFheBiX6YqryJ7t2sO09NQiO 96V7H54LmltOT/hEh6QWsJqb6BQgH65bswvV/XkYGja8/T0GzvbaVzAgEC 97-----END DH PARAMETERS----- 98""" 99 100 101skip_if_py3 = pytest.mark.skipif(PY3, reason="Python 2 only") 102 103 104def join_bytes_or_unicode(prefix, suffix): 105 """ 106 Join two path components of either ``bytes`` or ``unicode``. 107 108 The return type is the same as the type of ``prefix``. 109 """ 110 # If the types are the same, nothing special is necessary. 
111 if type(prefix) == type(suffix): 112 return join(prefix, suffix) 113 114 # Otherwise, coerce suffix to the type of prefix. 115 if isinstance(prefix, text_type): 116 return join(prefix, suffix.decode(getfilesystemencoding())) 117 else: 118 return join(prefix, suffix.encode(getfilesystemencoding())) 119 120 121def verify_cb(conn, cert, errnum, depth, ok): 122 return ok 123 124 125def socket_pair(): 126 """ 127 Establish and return a pair of network sockets connected to each other. 128 """ 129 # Connect a pair of sockets 130 port = socket() 131 port.bind(('', 0)) 132 port.listen(1) 133 client = socket() 134 client.setblocking(False) 135 client.connect_ex(("127.0.0.1", port.getsockname()[1])) 136 client.setblocking(True) 137 server = port.accept()[0] 138 139 # Let's pass some unencrypted data to make sure our socket connection is 140 # fine. Just one byte, so we don't have to worry about buffers getting 141 # filled up or fragmentation. 142 server.send(b"x") 143 assert client.recv(1024) == b"x" 144 client.send(b"y") 145 assert server.recv(1024) == b"y" 146 147 # Most of our callers want non-blocking sockets, make it easy for them. 148 server.setblocking(False) 149 client.setblocking(False) 150 151 return (server, client) 152 153 154def handshake(client, server): 155 conns = [client, server] 156 while conns: 157 for conn in conns: 158 try: 159 conn.do_handshake() 160 except WantReadError: 161 pass 162 else: 163 conns.remove(conn) 164 165 166def _create_certificate_chain(): 167 """ 168 Construct and return a chain of certificates. 169 170 1. A new self-signed certificate authority certificate (cacert) 171 2. A new intermediate certificate signed by cacert (icert) 172 3. 
A new server certificate signed by icert (scert) 173 """ 174 caext = X509Extension(b'basicConstraints', False, b'CA:true') 175 176 # Step 1 177 cakey = PKey() 178 cakey.generate_key(TYPE_RSA, 1024) 179 cacert = X509() 180 cacert.get_subject().commonName = "Authority Certificate" 181 cacert.set_issuer(cacert.get_subject()) 182 cacert.set_pubkey(cakey) 183 cacert.set_notBefore(b"20000101000000Z") 184 cacert.set_notAfter(b"20200101000000Z") 185 cacert.add_extensions([caext]) 186 cacert.set_serial_number(0) 187 cacert.sign(cakey, "sha1") 188 189 # Step 2 190 ikey = PKey() 191 ikey.generate_key(TYPE_RSA, 1024) 192 icert = X509() 193 icert.get_subject().commonName = "Intermediate Certificate" 194 icert.set_issuer(cacert.get_subject()) 195 icert.set_pubkey(ikey) 196 icert.set_notBefore(b"20000101000000Z") 197 icert.set_notAfter(b"20200101000000Z") 198 icert.add_extensions([caext]) 199 icert.set_serial_number(0) 200 icert.sign(cakey, "sha1") 201 202 # Step 3 203 skey = PKey() 204 skey.generate_key(TYPE_RSA, 1024) 205 scert = X509() 206 scert.get_subject().commonName = "Server Certificate" 207 scert.set_issuer(icert.get_subject()) 208 scert.set_pubkey(skey) 209 scert.set_notBefore(b"20000101000000Z") 210 scert.set_notAfter(b"20200101000000Z") 211 scert.add_extensions([ 212 X509Extension(b'basicConstraints', True, b'CA:false')]) 213 scert.set_serial_number(0) 214 scert.sign(ikey, "sha1") 215 216 return [(cakey, cacert), (ikey, icert), (skey, scert)] 217 218 219def loopback_client_factory(socket, version=SSLv23_METHOD): 220 client = Connection(Context(version), socket) 221 client.set_connect_state() 222 return client 223 224 225def loopback_server_factory(socket, version=SSLv23_METHOD): 226 ctx = Context(version) 227 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 228 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 229 server = Connection(ctx, socket) 230 server.set_accept_state() 231 return server 232 233 234def 
loopback(server_factory=None, client_factory=None): 235 """ 236 Create a connected socket pair and force two connected SSL sockets 237 to talk to each other via memory BIOs. 238 """ 239 if server_factory is None: 240 server_factory = loopback_server_factory 241 if client_factory is None: 242 client_factory = loopback_client_factory 243 244 (server, client) = socket_pair() 245 server = server_factory(server) 246 client = client_factory(client) 247 248 handshake(client, server) 249 250 server.setblocking(True) 251 client.setblocking(True) 252 return server, client 253 254 255def interact_in_memory(client_conn, server_conn): 256 """ 257 Try to read application bytes from each of the two `Connection` objects. 258 Copy bytes back and forth between their send/receive buffers for as long 259 as there is anything to copy. When there is nothing more to copy, 260 return `None`. If one of them actually manages to deliver some application 261 bytes, return a two-tuple of the connection from which the bytes were read 262 and the bytes themselves. 263 """ 264 wrote = True 265 while wrote: 266 # Loop until neither side has anything to say 267 wrote = False 268 269 # Copy stuff from each side's send buffer to the other side's 270 # receive buffer. 271 for (read, write) in [(client_conn, server_conn), 272 (server_conn, client_conn)]: 273 274 # Give the side a chance to generate some more bytes, or succeed. 275 try: 276 data = read.recv(2 ** 16) 277 except WantReadError: 278 # It didn't succeed, so we'll hope it generated some output. 279 pass 280 else: 281 # It did succeed, so we'll stop now and let the caller deal 282 # with it. 283 return (read, data) 284 285 while True: 286 # Keep copying as long as there's more stuff there. 287 try: 288 dirty = read.bio_read(4096) 289 except WantReadError: 290 # Okay, nothing more waiting to be sent. Stop 291 # processing this send buffer. 292 break 293 else: 294 # Keep track of the fact that someone generated some 295 # output. 
296 wrote = True 297 write.bio_write(dirty) 298 299 300def handshake_in_memory(client_conn, server_conn): 301 """ 302 Perform the TLS handshake between two `Connection` instances connected to 303 each other via memory BIOs. 304 """ 305 client_conn.set_connect_state() 306 server_conn.set_accept_state() 307 308 for conn in [client_conn, server_conn]: 309 try: 310 conn.do_handshake() 311 except WantReadError: 312 pass 313 314 interact_in_memory(client_conn, server_conn) 315 316 317class TestVersion(object): 318 """ 319 Tests for version information exposed by `OpenSSL.SSL.SSLeay_version` and 320 `OpenSSL.SSL.OPENSSL_VERSION_NUMBER`. 321 """ 322 def test_OPENSSL_VERSION_NUMBER(self): 323 """ 324 `OPENSSL_VERSION_NUMBER` is an integer with status in the low byte and 325 the patch, fix, minor, and major versions in the nibbles above that. 326 """ 327 assert isinstance(OPENSSL_VERSION_NUMBER, int) 328 329 def test_SSLeay_version(self): 330 """ 331 `SSLeay_version` takes a version type indicator and returns one of a 332 number of version strings based on that indicator. 333 """ 334 versions = {} 335 for t in [SSLEAY_VERSION, SSLEAY_CFLAGS, SSLEAY_BUILT_ON, 336 SSLEAY_PLATFORM, SSLEAY_DIR]: 337 version = SSLeay_version(t) 338 versions[version] = t 339 assert isinstance(version, bytes) 340 assert len(versions) == 5 341 342 343@pytest.fixture 344def ca_file(tmpdir): 345 """ 346 Create a valid PEM file with CA certificates and return the path. 
347 """ 348 key = rsa.generate_private_key( 349 public_exponent=65537, 350 key_size=2048, 351 backend=default_backend() 352 ) 353 public_key = key.public_key() 354 355 builder = x509.CertificateBuilder() 356 builder = builder.subject_name(x509.Name([ 357 x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), 358 ])) 359 builder = builder.issuer_name(x509.Name([ 360 x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), 361 ])) 362 one_day = datetime.timedelta(1, 0, 0) 363 builder = builder.not_valid_before(datetime.datetime.today() - one_day) 364 builder = builder.not_valid_after(datetime.datetime.today() + one_day) 365 builder = builder.serial_number(int(uuid.uuid4())) 366 builder = builder.public_key(public_key) 367 builder = builder.add_extension( 368 x509.BasicConstraints(ca=True, path_length=None), critical=True, 369 ) 370 371 certificate = builder.sign( 372 private_key=key, algorithm=hashes.SHA256(), 373 backend=default_backend() 374 ) 375 376 ca_file = tmpdir.join("test.pem") 377 ca_file.write_binary( 378 certificate.public_bytes( 379 encoding=serialization.Encoding.PEM, 380 ) 381 ) 382 383 return str(ca_file).encode("ascii") 384 385 386@pytest.fixture 387def context(): 388 """ 389 A simple TLS 1.0 context. 390 """ 391 return Context(TLSv1_METHOD) 392 393 394class TestContext(object): 395 """ 396 Unit tests for `OpenSSL.SSL.Context`. 397 """ 398 @pytest.mark.parametrize("cipher_string", [ 399 b"hello world:AES128-SHA", 400 u"hello world:AES128-SHA", 401 ]) 402 def test_set_cipher_list(self, context, cipher_string): 403 """ 404 `Context.set_cipher_list` accepts both byte and unicode strings 405 for naming the ciphers which connections created with the context 406 object will be able to choose from. 
407 """ 408 context.set_cipher_list(cipher_string) 409 conn = Connection(context, None) 410 411 assert "AES128-SHA" in conn.get_cipher_list() 412 413 @pytest.mark.parametrize("cipher_list,error", [ 414 (object(), TypeError), 415 ("imaginary-cipher", Error), 416 ]) 417 def test_set_cipher_list_wrong_args(self, context, cipher_list, error): 418 """ 419 `Context.set_cipher_list` raises `TypeError` when passed a non-string 420 argument and raises `OpenSSL.SSL.Error` when passed an incorrect cipher 421 list string. 422 """ 423 with pytest.raises(error): 424 context.set_cipher_list(cipher_list) 425 426 def test_load_client_ca(self, context, ca_file): 427 """ 428 `Context.load_client_ca` works as far as we can tell. 429 """ 430 context.load_client_ca(ca_file) 431 432 def test_load_client_ca_invalid(self, context, tmpdir): 433 """ 434 `Context.load_client_ca` raises an Error if the ca file is invalid. 435 """ 436 ca_file = tmpdir.join("test.pem") 437 ca_file.write("") 438 439 with pytest.raises(Error) as e: 440 context.load_client_ca(str(ca_file).encode("ascii")) 441 442 assert "PEM routines" == e.value.args[0][0][0] 443 444 def test_load_client_ca_unicode(self, context, ca_file): 445 """ 446 Passing the path as unicode raises a warning but works. 447 """ 448 pytest.deprecated_call( 449 context.load_client_ca, ca_file.decode("ascii") 450 ) 451 452 def test_set_session_id(self, context): 453 """ 454 `Context.set_session_id` works as far as we can tell. 455 """ 456 context.set_session_id(b"abc") 457 458 def test_set_session_id_fail(self, context): 459 """ 460 `Context.set_session_id` errors are propagated. 461 """ 462 with pytest.raises(Error) as e: 463 context.set_session_id(b"abc" * 1000) 464 465 assert [ 466 ("SSL routines", 467 "SSL_CTX_set_session_id_context", 468 "ssl session id context too long") 469 ] == e.value.args[0] 470 471 def test_set_session_id_unicode(self, context): 472 """ 473 `Context.set_session_id` raises a warning if a unicode string is 474 passed. 
475 """ 476 pytest.deprecated_call(context.set_session_id, u"abc") 477 478 def test_method(self): 479 """ 480 `Context` can be instantiated with one of `SSLv2_METHOD`, 481 `SSLv3_METHOD`, `SSLv23_METHOD`, `TLSv1_METHOD`, `TLSv1_1_METHOD`, 482 or `TLSv1_2_METHOD`. 483 """ 484 methods = [SSLv23_METHOD, TLSv1_METHOD] 485 for meth in methods: 486 Context(meth) 487 488 maybe = [SSLv2_METHOD, SSLv3_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD] 489 for meth in maybe: 490 try: 491 Context(meth) 492 except (Error, ValueError): 493 # Some versions of OpenSSL have SSLv2 / TLSv1.1 / TLSv1.2, some 494 # don't. Difficult to say in advance. 495 pass 496 497 with pytest.raises(TypeError): 498 Context("") 499 with pytest.raises(ValueError): 500 Context(10) 501 502 @skip_if_py3 503 def test_method_long(self): 504 """ 505 On Python 2 `Context` accepts values of type `long` as well as `int`. 506 """ 507 Context(long(TLSv1_METHOD)) 508 509 def test_type(self): 510 """ 511 `Context` and `ContextType` refer to the same type object and can 512 be used to create instances of that type. 513 """ 514 assert Context is ContextType 515 assert is_consistent_type(Context, 'Context', TLSv1_METHOD) 516 517 def test_use_privatekey(self): 518 """ 519 `Context.use_privatekey` takes an `OpenSSL.crypto.PKey` instance. 520 """ 521 key = PKey() 522 key.generate_key(TYPE_RSA, 512) 523 ctx = Context(TLSv1_METHOD) 524 ctx.use_privatekey(key) 525 with pytest.raises(TypeError): 526 ctx.use_privatekey("") 527 528 def test_use_privatekey_file_missing(self, tmpfile): 529 """ 530 `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` when passed 531 the name of a file which does not exist. 532 """ 533 ctx = Context(TLSv1_METHOD) 534 with pytest.raises(Error): 535 ctx.use_privatekey_file(tmpfile) 536 537 def _use_privatekey_file_test(self, pemfile, filetype): 538 """ 539 Verify that calling ``Context.use_privatekey_file`` with the given 540 arguments does not raise an exception. 
541 """ 542 key = PKey() 543 key.generate_key(TYPE_RSA, 512) 544 545 with open(pemfile, "wt") as pem: 546 pem.write( 547 dump_privatekey(FILETYPE_PEM, key).decode("ascii") 548 ) 549 550 ctx = Context(TLSv1_METHOD) 551 ctx.use_privatekey_file(pemfile, filetype) 552 553 @pytest.mark.parametrize('filetype', [object(), "", None, 1.0]) 554 def test_wrong_privatekey_file_wrong_args(self, tmpfile, filetype): 555 """ 556 `Context.use_privatekey_file` raises `TypeError` when called with 557 a `filetype` which is not a valid file encoding. 558 """ 559 ctx = Context(TLSv1_METHOD) 560 with pytest.raises(TypeError): 561 ctx.use_privatekey_file(tmpfile, filetype) 562 563 def test_use_privatekey_file_bytes(self, tmpfile): 564 """ 565 A private key can be specified from a file by passing a ``bytes`` 566 instance giving the file name to ``Context.use_privatekey_file``. 567 """ 568 self._use_privatekey_file_test( 569 tmpfile + NON_ASCII.encode(getfilesystemencoding()), 570 FILETYPE_PEM, 571 ) 572 573 def test_use_privatekey_file_unicode(self, tmpfile): 574 """ 575 A private key can be specified from a file by passing a ``unicode`` 576 instance giving the file name to ``Context.use_privatekey_file``. 577 """ 578 self._use_privatekey_file_test( 579 tmpfile.decode(getfilesystemencoding()) + NON_ASCII, 580 FILETYPE_PEM, 581 ) 582 583 @skip_if_py3 584 def test_use_privatekey_file_long(self, tmpfile): 585 """ 586 On Python 2 `Context.use_privatekey_file` accepts a filetype of 587 type `long` as well as `int`. 588 """ 589 self._use_privatekey_file_test(tmpfile, long(FILETYPE_PEM)) 590 591 def test_use_certificate_wrong_args(self): 592 """ 593 `Context.use_certificate_wrong_args` raises `TypeError` when not passed 594 exactly one `OpenSSL.crypto.X509` instance as an argument. 
595 """ 596 ctx = Context(TLSv1_METHOD) 597 with pytest.raises(TypeError): 598 ctx.use_certificate("hello, world") 599 600 def test_use_certificate_uninitialized(self): 601 """ 602 `Context.use_certificate` raises `OpenSSL.SSL.Error` when passed a 603 `OpenSSL.crypto.X509` instance which has not been initialized 604 (ie, which does not actually have any certificate data). 605 """ 606 ctx = Context(TLSv1_METHOD) 607 with pytest.raises(Error): 608 ctx.use_certificate(X509()) 609 610 def test_use_certificate(self): 611 """ 612 `Context.use_certificate` sets the certificate which will be 613 used to identify connections created using the context. 614 """ 615 # TODO 616 # Hard to assert anything. But we could set a privatekey then ask 617 # OpenSSL if the cert and key agree using check_privatekey. Then as 618 # long as check_privatekey works right we're good... 619 ctx = Context(TLSv1_METHOD) 620 ctx.use_certificate( 621 load_certificate(FILETYPE_PEM, cleartextCertificatePEM) 622 ) 623 624 def test_use_certificate_file_wrong_args(self): 625 """ 626 `Context.use_certificate_file` raises `TypeError` if the first 627 argument is not a byte string or the second argument is not an integer. 628 """ 629 ctx = Context(TLSv1_METHOD) 630 with pytest.raises(TypeError): 631 ctx.use_certificate_file(object(), FILETYPE_PEM) 632 with pytest.raises(TypeError): 633 ctx.use_certificate_file(b"somefile", object()) 634 with pytest.raises(TypeError): 635 ctx.use_certificate_file(object(), FILETYPE_PEM) 636 637 def test_use_certificate_file_missing(self, tmpfile): 638 """ 639 `Context.use_certificate_file` raises `OpenSSL.SSL.Error` if passed 640 the name of a file which does not exist. 641 """ 642 ctx = Context(TLSv1_METHOD) 643 with pytest.raises(Error): 644 ctx.use_certificate_file(tmpfile) 645 646 def _use_certificate_file_test(self, certificate_file): 647 """ 648 Verify that calling ``Context.use_certificate_file`` with the given 649 filename doesn't raise an exception. 
650 """ 651 # TODO 652 # Hard to assert anything. But we could set a privatekey then ask 653 # OpenSSL if the cert and key agree using check_privatekey. Then as 654 # long as check_privatekey works right we're good... 655 with open(certificate_file, "wb") as pem_file: 656 pem_file.write(cleartextCertificatePEM) 657 658 ctx = Context(TLSv1_METHOD) 659 ctx.use_certificate_file(certificate_file) 660 661 def test_use_certificate_file_bytes(self, tmpfile): 662 """ 663 `Context.use_certificate_file` sets the certificate (given as a 664 `bytes` filename) which will be used to identify connections created 665 using the context. 666 """ 667 filename = tmpfile + NON_ASCII.encode(getfilesystemencoding()) 668 self._use_certificate_file_test(filename) 669 670 def test_use_certificate_file_unicode(self, tmpfile): 671 """ 672 `Context.use_certificate_file` sets the certificate (given as a 673 `bytes` filename) which will be used to identify connections created 674 using the context. 675 """ 676 filename = tmpfile.decode(getfilesystemencoding()) + NON_ASCII 677 self._use_certificate_file_test(filename) 678 679 @skip_if_py3 680 def test_use_certificate_file_long(self, tmpfile): 681 """ 682 On Python 2 `Context.use_certificate_file` accepts a 683 filetype of type `long` as well as `int`. 684 """ 685 pem_filename = tmpfile 686 with open(pem_filename, "wb") as pem_file: 687 pem_file.write(cleartextCertificatePEM) 688 689 ctx = Context(TLSv1_METHOD) 690 ctx.use_certificate_file(pem_filename, long(FILETYPE_PEM)) 691 692 def test_check_privatekey_valid(self): 693 """ 694 `Context.check_privatekey` returns `None` if the `Context` instance 695 has been configured to use a matched key and certificate pair. 
696 """ 697 key = load_privatekey(FILETYPE_PEM, client_key_pem) 698 cert = load_certificate(FILETYPE_PEM, client_cert_pem) 699 context = Context(TLSv1_METHOD) 700 context.use_privatekey(key) 701 context.use_certificate(cert) 702 assert None is context.check_privatekey() 703 704 def test_check_privatekey_invalid(self): 705 """ 706 `Context.check_privatekey` raises `Error` if the `Context` instance 707 has been configured to use a key and certificate pair which don't 708 relate to each other. 709 """ 710 key = load_privatekey(FILETYPE_PEM, client_key_pem) 711 cert = load_certificate(FILETYPE_PEM, server_cert_pem) 712 context = Context(TLSv1_METHOD) 713 context.use_privatekey(key) 714 context.use_certificate(cert) 715 with pytest.raises(Error): 716 context.check_privatekey() 717 718 def test_app_data(self): 719 """ 720 `Context.set_app_data` stores an object for later retrieval 721 using `Context.get_app_data`. 722 """ 723 app_data = object() 724 context = Context(TLSv1_METHOD) 725 context.set_app_data(app_data) 726 assert context.get_app_data() is app_data 727 728 def test_set_options_wrong_args(self): 729 """ 730 `Context.set_options` raises `TypeError` if called with 731 a non-`int` argument. 732 """ 733 context = Context(TLSv1_METHOD) 734 with pytest.raises(TypeError): 735 context.set_options(None) 736 737 def test_set_options(self): 738 """ 739 `Context.set_options` returns the new options value. 740 """ 741 context = Context(TLSv1_METHOD) 742 options = context.set_options(OP_NO_SSLv2) 743 assert options & OP_NO_SSLv2 == OP_NO_SSLv2 744 745 @skip_if_py3 746 def test_set_options_long(self): 747 """ 748 On Python 2 `Context.set_options` accepts values of type 749 `long` as well as `int`. 750 """ 751 context = Context(TLSv1_METHOD) 752 options = context.set_options(long(OP_NO_SSLv2)) 753 assert options & OP_NO_SSLv2 == OP_NO_SSLv2 754 755 def test_set_mode_wrong_args(self): 756 """ 757 `Context.set_mode` raises `TypeError` if called with 758 a non-`int` argument. 
759 """ 760 context = Context(TLSv1_METHOD) 761 with pytest.raises(TypeError): 762 context.set_mode(None) 763 764 def test_set_mode(self): 765 """ 766 `Context.set_mode` accepts a mode bitvector and returns the 767 newly set mode. 768 """ 769 context = Context(TLSv1_METHOD) 770 assert MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS) 771 772 @skip_if_py3 773 def test_set_mode_long(self): 774 """ 775 On Python 2 `Context.set_mode` accepts values of type `long` as well 776 as `int`. 777 """ 778 context = Context(TLSv1_METHOD) 779 mode = context.set_mode(long(MODE_RELEASE_BUFFERS)) 780 assert MODE_RELEASE_BUFFERS & mode 781 782 def test_set_timeout_wrong_args(self): 783 """ 784 `Context.set_timeout` raises `TypeError` if called with 785 a non-`int` argument. 786 """ 787 context = Context(TLSv1_METHOD) 788 with pytest.raises(TypeError): 789 context.set_timeout(None) 790 791 def test_timeout(self): 792 """ 793 `Context.set_timeout` sets the session timeout for all connections 794 created using the context object. `Context.get_timeout` retrieves 795 this value. 796 """ 797 context = Context(TLSv1_METHOD) 798 context.set_timeout(1234) 799 assert context.get_timeout() == 1234 800 801 @skip_if_py3 802 def test_timeout_long(self): 803 """ 804 On Python 2 `Context.set_timeout` accepts values of type `long` as 805 well as int. 806 """ 807 context = Context(TLSv1_METHOD) 808 context.set_timeout(long(1234)) 809 assert context.get_timeout() == 1234 810 811 def test_set_verify_depth_wrong_args(self): 812 """ 813 `Context.set_verify_depth` raises `TypeError` if called with a 814 non-`int` argument. 815 """ 816 context = Context(TLSv1_METHOD) 817 with pytest.raises(TypeError): 818 context.set_verify_depth(None) 819 820 def test_verify_depth(self): 821 """ 822 `Context.set_verify_depth` sets the number of certificates in 823 a chain to follow before giving up. The value can be retrieved with 824 `Context.get_verify_depth`. 
825 """ 826 context = Context(TLSv1_METHOD) 827 context.set_verify_depth(11) 828 assert context.get_verify_depth() == 11 829 830 @skip_if_py3 831 def test_verify_depth_long(self): 832 """ 833 On Python 2 `Context.set_verify_depth` accepts values of type `long` 834 as well as int. 835 """ 836 context = Context(TLSv1_METHOD) 837 context.set_verify_depth(long(11)) 838 assert context.get_verify_depth() == 11 839 840 def _write_encrypted_pem(self, passphrase, tmpfile): 841 """ 842 Write a new private key out to a new file, encrypted using the given 843 passphrase. Return the path to the new file. 844 """ 845 key = PKey() 846 key.generate_key(TYPE_RSA, 512) 847 pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase) 848 with open(tmpfile, 'w') as fObj: 849 fObj.write(pem.decode('ascii')) 850 return tmpfile 851 852 def test_set_passwd_cb_wrong_args(self): 853 """ 854 `Context.set_passwd_cb` raises `TypeError` if called with a 855 non-callable first argument. 856 """ 857 context = Context(TLSv1_METHOD) 858 with pytest.raises(TypeError): 859 context.set_passwd_cb(None) 860 861 def test_set_passwd_cb(self, tmpfile): 862 """ 863 `Context.set_passwd_cb` accepts a callable which will be invoked when 864 a private key is loaded from an encrypted PEM. 865 """ 866 passphrase = b"foobar" 867 pemFile = self._write_encrypted_pem(passphrase, tmpfile) 868 calledWith = [] 869 870 def passphraseCallback(maxlen, verify, extra): 871 calledWith.append((maxlen, verify, extra)) 872 return passphrase 873 context = Context(TLSv1_METHOD) 874 context.set_passwd_cb(passphraseCallback) 875 context.use_privatekey_file(pemFile) 876 assert len(calledWith) == 1 877 assert isinstance(calledWith[0][0], int) 878 assert isinstance(calledWith[0][1], int) 879 assert calledWith[0][2] is None 880 881 def test_passwd_callback_exception(self, tmpfile): 882 """ 883 `Context.use_privatekey_file` propagates any exception raised 884 by the passphrase callback. 
885 """ 886 pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) 887 888 def passphraseCallback(maxlen, verify, extra): 889 raise RuntimeError("Sorry, I am a fail.") 890 891 context = Context(TLSv1_METHOD) 892 context.set_passwd_cb(passphraseCallback) 893 with pytest.raises(RuntimeError): 894 context.use_privatekey_file(pemFile) 895 896 def test_passwd_callback_false(self, tmpfile): 897 """ 898 `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the 899 passphrase callback returns a false value. 900 """ 901 pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) 902 903 def passphraseCallback(maxlen, verify, extra): 904 return b"" 905 906 context = Context(TLSv1_METHOD) 907 context.set_passwd_cb(passphraseCallback) 908 with pytest.raises(Error): 909 context.use_privatekey_file(pemFile) 910 911 def test_passwd_callback_non_string(self, tmpfile): 912 """ 913 `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the 914 passphrase callback returns a true non-string value. 915 """ 916 pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) 917 918 def passphraseCallback(maxlen, verify, extra): 919 return 10 920 921 context = Context(TLSv1_METHOD) 922 context.set_passwd_cb(passphraseCallback) 923 # TODO: Surely this is the wrong error? 924 with pytest.raises(ValueError): 925 context.use_privatekey_file(pemFile) 926 927 def test_passwd_callback_too_long(self, tmpfile): 928 """ 929 If the passphrase returned by the passphrase callback returns a string 930 longer than the indicated maximum length, it is truncated. 931 """ 932 # A priori knowledge! 933 passphrase = b"x" * 1024 934 pemFile = self._write_encrypted_pem(passphrase, tmpfile) 935 936 def passphraseCallback(maxlen, verify, extra): 937 assert maxlen == 1024 938 return passphrase + b"y" 939 940 context = Context(TLSv1_METHOD) 941 context.set_passwd_cb(passphraseCallback) 942 # This shall succeed because the truncated result is the correct 943 # passphrase. 
944 context.use_privatekey_file(pemFile) 945 946 def test_set_info_callback(self): 947 """ 948 `Context.set_info_callback` accepts a callable which will be 949 invoked when certain information about an SSL connection is available. 950 """ 951 (server, client) = socket_pair() 952 953 clientSSL = Connection(Context(TLSv1_METHOD), client) 954 clientSSL.set_connect_state() 955 956 called = [] 957 958 def info(conn, where, ret): 959 called.append((conn, where, ret)) 960 context = Context(TLSv1_METHOD) 961 context.set_info_callback(info) 962 context.use_certificate( 963 load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) 964 context.use_privatekey( 965 load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) 966 967 serverSSL = Connection(context, server) 968 serverSSL.set_accept_state() 969 970 handshake(clientSSL, serverSSL) 971 972 # The callback must always be called with a Connection instance as the 973 # first argument. It would probably be better to split this into 974 # separate tests for client and server side info callbacks so we could 975 # assert it is called with the right Connection instance. It would 976 # also be good to assert *something* about `where` and `ret`. 977 notConnections = [ 978 conn for (conn, where, ret) in called 979 if not isinstance(conn, Connection)] 980 assert [] == notConnections, ( 981 "Some info callback arguments were not Connection instances.") 982 983 def _load_verify_locations_test(self, *args): 984 """ 985 Create a client context which will verify the peer certificate and call 986 its `load_verify_locations` method with the given arguments. 987 Then connect it to a server and ensure that the handshake succeeds. 988 """ 989 (server, client) = socket_pair() 990 991 clientContext = Context(TLSv1_METHOD) 992 clientContext.load_verify_locations(*args) 993 # Require that the server certificate verify properly or the 994 # connection will fail. 
995 clientContext.set_verify( 996 VERIFY_PEER, 997 lambda conn, cert, errno, depth, preverify_ok: preverify_ok) 998 999 clientSSL = Connection(clientContext, client) 1000 clientSSL.set_connect_state() 1001 1002 serverContext = Context(TLSv1_METHOD) 1003 serverContext.use_certificate( 1004 load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) 1005 serverContext.use_privatekey( 1006 load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) 1007 1008 serverSSL = Connection(serverContext, server) 1009 serverSSL.set_accept_state() 1010 1011 # Without load_verify_locations above, the handshake 1012 # will fail: 1013 # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE', 1014 # 'certificate verify failed')] 1015 handshake(clientSSL, serverSSL) 1016 1017 cert = clientSSL.get_peer_certificate() 1018 assert cert.get_subject().CN == 'Testing Root CA' 1019 1020 def _load_verify_cafile(self, cafile): 1021 """ 1022 Verify that if path to a file containing a certificate is passed to 1023 `Context.load_verify_locations` for the ``cafile`` parameter, that 1024 certificate is used as a trust root for the purposes of verifying 1025 connections created using that `Context`. 1026 """ 1027 with open(cafile, 'w') as fObj: 1028 fObj.write(cleartextCertificatePEM.decode('ascii')) 1029 1030 self._load_verify_locations_test(cafile) 1031 1032 def test_load_verify_bytes_cafile(self, tmpfile): 1033 """ 1034 `Context.load_verify_locations` accepts a file name as a `bytes` 1035 instance and uses the certificates within for verification purposes. 1036 """ 1037 cafile = tmpfile + NON_ASCII.encode(getfilesystemencoding()) 1038 self._load_verify_cafile(cafile) 1039 1040 def test_load_verify_unicode_cafile(self, tmpfile): 1041 """ 1042 `Context.load_verify_locations` accepts a file name as a `unicode` 1043 instance and uses the certificates within for verification purposes. 
1044 """ 1045 self._load_verify_cafile( 1046 tmpfile.decode(getfilesystemencoding()) + NON_ASCII 1047 ) 1048 1049 def test_load_verify_invalid_file(self, tmpfile): 1050 """ 1051 `Context.load_verify_locations` raises `Error` when passed a 1052 non-existent cafile. 1053 """ 1054 clientContext = Context(TLSv1_METHOD) 1055 with pytest.raises(Error): 1056 clientContext.load_verify_locations(tmpfile) 1057 1058 def _load_verify_directory_locations_capath(self, capath): 1059 """ 1060 Verify that if path to a directory containing certificate files is 1061 passed to ``Context.load_verify_locations`` for the ``capath`` 1062 parameter, those certificates are used as trust roots for the purposes 1063 of verifying connections created using that ``Context``. 1064 """ 1065 makedirs(capath) 1066 # Hash values computed manually with c_rehash to avoid depending on 1067 # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other 1068 # from OpenSSL 1.0.0. 1069 for name in [b'c7adac82.0', b'c3705638.0']: 1070 cafile = join_bytes_or_unicode(capath, name) 1071 with open(cafile, 'w') as fObj: 1072 fObj.write(cleartextCertificatePEM.decode('ascii')) 1073 1074 self._load_verify_locations_test(None, capath) 1075 1076 def test_load_verify_directory_bytes_capath(self, tmpfile): 1077 """ 1078 `Context.load_verify_locations` accepts a directory name as a `bytes` 1079 instance and uses the certificates within for verification purposes. 1080 """ 1081 self._load_verify_directory_locations_capath( 1082 tmpfile + NON_ASCII.encode(getfilesystemencoding()) 1083 ) 1084 1085 def test_load_verify_directory_unicode_capath(self, tmpfile): 1086 """ 1087 `Context.load_verify_locations` accepts a directory name as a `unicode` 1088 instance and uses the certificates within for verification purposes. 
1089 """ 1090 self._load_verify_directory_locations_capath( 1091 tmpfile.decode(getfilesystemencoding()) + NON_ASCII 1092 ) 1093 1094 def test_load_verify_locations_wrong_args(self): 1095 """ 1096 `Context.load_verify_locations` raises `TypeError` if with non-`str` 1097 arguments. 1098 """ 1099 context = Context(TLSv1_METHOD) 1100 with pytest.raises(TypeError): 1101 context.load_verify_locations(object()) 1102 with pytest.raises(TypeError): 1103 context.load_verify_locations(object(), object()) 1104 1105 @pytest.mark.skipif( 1106 not platform.startswith("linux"), 1107 reason="Loading fallback paths is a linux-specific behavior to " 1108 "accommodate pyca/cryptography manylinux1 wheels" 1109 ) 1110 def test_fallback_default_verify_paths(self, monkeypatch): 1111 """ 1112 Test that we load certificates successfully on linux from the fallback 1113 path. To do this we set the _CRYPTOGRAPHY_MANYLINUX1_CA_FILE and 1114 _CRYPTOGRAPHY_MANYLINUX1_CA_DIR vars to be equal to whatever the 1115 current OpenSSL default is and we disable 1116 SSL_CTX_SET_default_verify_paths so that it can't find certs unless 1117 it loads via fallback. 1118 """ 1119 context = Context(TLSv1_METHOD) 1120 monkeypatch.setattr( 1121 _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1 1122 ) 1123 monkeypatch.setattr( 1124 SSL, 1125 "_CRYPTOGRAPHY_MANYLINUX1_CA_FILE", 1126 _ffi.string(_lib.X509_get_default_cert_file()) 1127 ) 1128 monkeypatch.setattr( 1129 SSL, 1130 "_CRYPTOGRAPHY_MANYLINUX1_CA_DIR", 1131 _ffi.string(_lib.X509_get_default_cert_dir()) 1132 ) 1133 context.set_default_verify_paths() 1134 store = context.get_cert_store() 1135 sk_obj = _lib.X509_STORE_get0_objects(store._store) 1136 assert sk_obj != _ffi.NULL 1137 num = _lib.sk_X509_OBJECT_num(sk_obj) 1138 assert num != 0 1139 1140 def test_check_env_vars(self, monkeypatch): 1141 """ 1142 Test that we return True/False appropriately if the env vars are set. 
1143 """ 1144 context = Context(TLSv1_METHOD) 1145 dir_var = "CUSTOM_DIR_VAR" 1146 file_var = "CUSTOM_FILE_VAR" 1147 assert context._check_env_vars_set(dir_var, file_var) is False 1148 monkeypatch.setenv(dir_var, "value") 1149 monkeypatch.setenv(file_var, "value") 1150 assert context._check_env_vars_set(dir_var, file_var) is True 1151 assert context._check_env_vars_set(dir_var, file_var) is True 1152 1153 def test_verify_no_fallback_if_env_vars_set(self, monkeypatch): 1154 """ 1155 Test that we don't use the fallback path if env vars are set. 1156 """ 1157 context = Context(TLSv1_METHOD) 1158 monkeypatch.setattr( 1159 _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1 1160 ) 1161 dir_env_var = _ffi.string( 1162 _lib.X509_get_default_cert_dir_env() 1163 ).decode("ascii") 1164 file_env_var = _ffi.string( 1165 _lib.X509_get_default_cert_file_env() 1166 ).decode("ascii") 1167 monkeypatch.setenv(dir_env_var, "value") 1168 monkeypatch.setenv(file_env_var, "value") 1169 context.set_default_verify_paths() 1170 1171 monkeypatch.setattr( 1172 context, 1173 "_fallback_default_verify_paths", 1174 raiser(SystemError) 1175 ) 1176 context.set_default_verify_paths() 1177 1178 @pytest.mark.skipif( 1179 platform == "win32", 1180 reason="set_default_verify_paths appears not to work on Windows. " 1181 "See LP#404343 and LP#404344." 1182 ) 1183 def test_set_default_verify_paths(self): 1184 """ 1185 `Context.set_default_verify_paths` causes the platform-specific CA 1186 certificate locations to be used for verification purposes. 1187 """ 1188 # Testing this requires a server with a certificate signed by one 1189 # of the CAs in the platform CA location. Getting one of those 1190 # costs money. Fortunately (or unfortunately, depending on your 1191 # perspective), it's easy to think of a public server on the 1192 # internet which has such a certificate. Connecting to the network 1193 # in a unit test is bad, but it's the only way I can think of to 1194 # really test this. 
-exarkun 1195 context = Context(SSLv23_METHOD) 1196 context.set_default_verify_paths() 1197 context.set_verify( 1198 VERIFY_PEER, 1199 lambda conn, cert, errno, depth, preverify_ok: preverify_ok) 1200 1201 client = socket() 1202 client.connect(("encrypted.google.com", 443)) 1203 clientSSL = Connection(context, client) 1204 clientSSL.set_connect_state() 1205 clientSSL.set_tlsext_host_name(b"encrypted.google.com") 1206 clientSSL.do_handshake() 1207 clientSSL.send(b"GET / HTTP/1.0\r\n\r\n") 1208 assert clientSSL.recv(1024) 1209 1210 def test_fallback_path_is_not_file_or_dir(self): 1211 """ 1212 Test that when passed empty arrays or paths that do not exist no 1213 errors are raised. 1214 """ 1215 context = Context(TLSv1_METHOD) 1216 context._fallback_default_verify_paths([], []) 1217 context._fallback_default_verify_paths( 1218 ["/not/a/file"], ["/not/a/dir"] 1219 ) 1220 1221 def test_add_extra_chain_cert_invalid_cert(self): 1222 """ 1223 `Context.add_extra_chain_cert` raises `TypeError` if called with an 1224 object which is not an instance of `X509`. 1225 """ 1226 context = Context(TLSv1_METHOD) 1227 with pytest.raises(TypeError): 1228 context.add_extra_chain_cert(object()) 1229 1230 def _handshake_test(self, serverContext, clientContext): 1231 """ 1232 Verify that a client and server created with the given contexts can 1233 successfully handshake and communicate. 1234 """ 1235 serverSocket, clientSocket = socket_pair() 1236 1237 server = Connection(serverContext, serverSocket) 1238 server.set_accept_state() 1239 1240 client = Connection(clientContext, clientSocket) 1241 client.set_connect_state() 1242 1243 # Make them talk to each other. 
1244 # interact_in_memory(client, server) 1245 for _ in range(3): 1246 for s in [client, server]: 1247 try: 1248 s.do_handshake() 1249 except WantReadError: 1250 pass 1251 1252 def test_set_verify_callback_connection_argument(self): 1253 """ 1254 The first argument passed to the verify callback is the 1255 `Connection` instance for which verification is taking place. 1256 """ 1257 serverContext = Context(TLSv1_METHOD) 1258 serverContext.use_privatekey( 1259 load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) 1260 serverContext.use_certificate( 1261 load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) 1262 serverConnection = Connection(serverContext, None) 1263 1264 class VerifyCallback(object): 1265 def callback(self, connection, *args): 1266 self.connection = connection 1267 return 1 1268 1269 verify = VerifyCallback() 1270 clientContext = Context(TLSv1_METHOD) 1271 clientContext.set_verify(VERIFY_PEER, verify.callback) 1272 clientConnection = Connection(clientContext, None) 1273 clientConnection.set_connect_state() 1274 1275 handshake_in_memory(clientConnection, serverConnection) 1276 1277 assert verify.connection is clientConnection 1278 1279 def test_x509_in_verify_works(self): 1280 """ 1281 We had a bug where the X509 cert instantiated in the callback wrapper 1282 didn't __init__ so it was missing objects needed when calling 1283 get_subject. This test sets up a handshake where we call get_subject 1284 on the cert provided to the verify callback. 
1285 """ 1286 serverContext = Context(TLSv1_METHOD) 1287 serverContext.use_privatekey( 1288 load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) 1289 serverContext.use_certificate( 1290 load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) 1291 serverConnection = Connection(serverContext, None) 1292 1293 def verify_cb_get_subject(conn, cert, errnum, depth, ok): 1294 assert cert.get_subject() 1295 return 1 1296 1297 clientContext = Context(TLSv1_METHOD) 1298 clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject) 1299 clientConnection = Connection(clientContext, None) 1300 clientConnection.set_connect_state() 1301 1302 handshake_in_memory(clientConnection, serverConnection) 1303 1304 def test_set_verify_callback_exception(self): 1305 """ 1306 If the verify callback passed to `Context.set_verify` raises an 1307 exception, verification fails and the exception is propagated to the 1308 caller of `Connection.do_handshake`. 1309 """ 1310 serverContext = Context(TLSv1_2_METHOD) 1311 serverContext.use_privatekey( 1312 load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)) 1313 serverContext.use_certificate( 1314 load_certificate(FILETYPE_PEM, cleartextCertificatePEM)) 1315 1316 clientContext = Context(TLSv1_2_METHOD) 1317 1318 def verify_callback(*args): 1319 raise Exception("silly verify failure") 1320 clientContext.set_verify(VERIFY_PEER, verify_callback) 1321 1322 with pytest.raises(Exception) as exc: 1323 self._handshake_test(serverContext, clientContext) 1324 1325 assert "silly verify failure" == str(exc.value) 1326 1327 def test_add_extra_chain_cert(self, tmpdir): 1328 """ 1329 `Context.add_extra_chain_cert` accepts an `X509` 1330 instance to add to the certificate chain. 1331 1332 See `_create_certificate_chain` for the details of the 1333 certificate chain tested. 1334 1335 The chain is tested by starting a server with scert and connecting 1336 to it with a client which trusts cacert and requires verification to 1337 succeed. 
1338 """ 1339 chain = _create_certificate_chain() 1340 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 1341 1342 # Dump the CA certificate to a file because that's the only way to load 1343 # it as a trusted CA in the client context. 1344 for cert, name in [(cacert, 'ca.pem'), 1345 (icert, 'i.pem'), 1346 (scert, 's.pem')]: 1347 with tmpdir.join(name).open('w') as f: 1348 f.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii')) 1349 1350 for key, name in [(cakey, 'ca.key'), 1351 (ikey, 'i.key'), 1352 (skey, 's.key')]: 1353 with tmpdir.join(name).open('w') as f: 1354 f.write(dump_privatekey(FILETYPE_PEM, key).decode('ascii')) 1355 1356 # Create the server context 1357 serverContext = Context(TLSv1_METHOD) 1358 serverContext.use_privatekey(skey) 1359 serverContext.use_certificate(scert) 1360 # The client already has cacert, we only need to give them icert. 1361 serverContext.add_extra_chain_cert(icert) 1362 1363 # Create the client 1364 clientContext = Context(TLSv1_METHOD) 1365 clientContext.set_verify( 1366 VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) 1367 clientContext.load_verify_locations(str(tmpdir.join("ca.pem"))) 1368 1369 # Try it out. 1370 self._handshake_test(serverContext, clientContext) 1371 1372 def _use_certificate_chain_file_test(self, certdir): 1373 """ 1374 Verify that `Context.use_certificate_chain_file` reads a 1375 certificate chain from a specified file. 1376 1377 The chain is tested by starting a server with scert and connecting to 1378 it with a client which trusts cacert and requires verification to 1379 succeed. 1380 """ 1381 chain = _create_certificate_chain() 1382 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 1383 1384 makedirs(certdir) 1385 1386 chainFile = join_bytes_or_unicode(certdir, "chain.pem") 1387 caFile = join_bytes_or_unicode(certdir, "ca.pem") 1388 1389 # Write out the chain file. 1390 with open(chainFile, 'wb') as fObj: 1391 # Most specific to least general. 
1392 fObj.write(dump_certificate(FILETYPE_PEM, scert)) 1393 fObj.write(dump_certificate(FILETYPE_PEM, icert)) 1394 fObj.write(dump_certificate(FILETYPE_PEM, cacert)) 1395 1396 with open(caFile, 'w') as fObj: 1397 fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode('ascii')) 1398 1399 serverContext = Context(TLSv1_METHOD) 1400 serverContext.use_certificate_chain_file(chainFile) 1401 serverContext.use_privatekey(skey) 1402 1403 clientContext = Context(TLSv1_METHOD) 1404 clientContext.set_verify( 1405 VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb) 1406 clientContext.load_verify_locations(caFile) 1407 1408 self._handshake_test(serverContext, clientContext) 1409 1410 def test_use_certificate_chain_file_bytes(self, tmpfile): 1411 """ 1412 ``Context.use_certificate_chain_file`` accepts the name of a file (as 1413 an instance of ``bytes``) to specify additional certificates to use to 1414 construct and verify a trust chain. 1415 """ 1416 self._use_certificate_chain_file_test( 1417 tmpfile + NON_ASCII.encode(getfilesystemencoding()) 1418 ) 1419 1420 def test_use_certificate_chain_file_unicode(self, tmpfile): 1421 """ 1422 ``Context.use_certificate_chain_file`` accepts the name of a file (as 1423 an instance of ``unicode``) to specify additional certificates to use 1424 to construct and verify a trust chain. 1425 """ 1426 self._use_certificate_chain_file_test( 1427 tmpfile.decode(getfilesystemencoding()) + NON_ASCII 1428 ) 1429 1430 def test_use_certificate_chain_file_wrong_args(self): 1431 """ 1432 `Context.use_certificate_chain_file` raises `TypeError` if passed a 1433 non-byte string single argument. 
1434 """ 1435 context = Context(TLSv1_METHOD) 1436 with pytest.raises(TypeError): 1437 context.use_certificate_chain_file(object()) 1438 1439 def test_use_certificate_chain_file_missing_file(self, tmpfile): 1440 """ 1441 `Context.use_certificate_chain_file` raises `OpenSSL.SSL.Error` when 1442 passed a bad chain file name (for example, the name of a file which 1443 does not exist). 1444 """ 1445 context = Context(TLSv1_METHOD) 1446 with pytest.raises(Error): 1447 context.use_certificate_chain_file(tmpfile) 1448 1449 def test_set_verify_mode(self): 1450 """ 1451 `Context.get_verify_mode` returns the verify mode flags previously 1452 passed to `Context.set_verify`. 1453 """ 1454 context = Context(TLSv1_METHOD) 1455 assert context.get_verify_mode() == 0 1456 context.set_verify( 1457 VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None) 1458 assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) 1459 1460 @skip_if_py3 1461 def test_set_verify_mode_long(self): 1462 """ 1463 On Python 2 `Context.set_verify_mode` accepts values of type `long` 1464 as well as `int`. 1465 """ 1466 context = Context(TLSv1_METHOD) 1467 assert context.get_verify_mode() == 0 1468 context.set_verify( 1469 long(VERIFY_PEER | VERIFY_CLIENT_ONCE), lambda *args: None 1470 ) # pragma: nocover 1471 assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE) 1472 1473 @pytest.mark.parametrize('mode', [None, 1.0, object(), 'mode']) 1474 def test_set_verify_wrong_mode_arg(self, mode): 1475 """ 1476 `Context.set_verify` raises `TypeError` if the first argument is 1477 not an integer. 1478 """ 1479 context = Context(TLSv1_METHOD) 1480 with pytest.raises(TypeError): 1481 context.set_verify(mode=mode, callback=lambda *args: None) 1482 1483 @pytest.mark.parametrize('callback', [None, 1.0, 'mode', ('foo', 'bar')]) 1484 def test_set_verify_wrong_callable_arg(self, callback): 1485 """ 1486 `Context.set_verify` raises `TypeError` if the second argument 1487 is not callable. 
1488 """ 1489 context = Context(TLSv1_METHOD) 1490 with pytest.raises(TypeError): 1491 context.set_verify(mode=VERIFY_PEER, callback=callback) 1492 1493 def test_load_tmp_dh_wrong_args(self): 1494 """ 1495 `Context.load_tmp_dh` raises `TypeError` if called with a 1496 non-`str` argument. 1497 """ 1498 context = Context(TLSv1_METHOD) 1499 with pytest.raises(TypeError): 1500 context.load_tmp_dh(object()) 1501 1502 def test_load_tmp_dh_missing_file(self): 1503 """ 1504 `Context.load_tmp_dh` raises `OpenSSL.SSL.Error` if the 1505 specified file does not exist. 1506 """ 1507 context = Context(TLSv1_METHOD) 1508 with pytest.raises(Error): 1509 context.load_tmp_dh(b"hello") 1510 1511 def _load_tmp_dh_test(self, dhfilename): 1512 """ 1513 Verify that calling ``Context.load_tmp_dh`` with the given filename 1514 does not raise an exception. 1515 """ 1516 context = Context(TLSv1_METHOD) 1517 with open(dhfilename, "w") as dhfile: 1518 dhfile.write(dhparam) 1519 1520 context.load_tmp_dh(dhfilename) 1521 # XXX What should I assert here? -exarkun 1522 1523 def test_load_tmp_dh_bytes(self, tmpfile): 1524 """ 1525 `Context.load_tmp_dh` loads Diffie-Hellman parameters from the 1526 specified file (given as ``bytes``). 1527 """ 1528 self._load_tmp_dh_test( 1529 tmpfile + NON_ASCII.encode(getfilesystemencoding()), 1530 ) 1531 1532 def test_load_tmp_dh_unicode(self, tmpfile): 1533 """ 1534 `Context.load_tmp_dh` loads Diffie-Hellman parameters from the 1535 specified file (given as ``unicode``). 1536 """ 1537 self._load_tmp_dh_test( 1538 tmpfile.decode(getfilesystemencoding()) + NON_ASCII, 1539 ) 1540 1541 def test_set_tmp_ecdh(self): 1542 """ 1543 `Context.set_tmp_ecdh` sets the elliptic curve for Diffie-Hellman to 1544 the specified curve. 
1545 """ 1546 context = Context(TLSv1_METHOD) 1547 for curve in get_elliptic_curves(): 1548 if curve.name.startswith(u"Oakley-"): 1549 # Setting Oakley-EC2N-4 and Oakley-EC2N-3 adds 1550 # ('bignum routines', 'BN_mod_inverse', 'no inverse') to the 1551 # error queue on OpenSSL 1.0.2. 1552 continue 1553 # The only easily "assertable" thing is that it does not raise an 1554 # exception. 1555 context.set_tmp_ecdh(curve) 1556 1557 def test_set_session_cache_mode_wrong_args(self): 1558 """ 1559 `Context.set_session_cache_mode` raises `TypeError` if called with 1560 a non-integer argument. 1561 called with other than one integer argument. 1562 """ 1563 context = Context(TLSv1_METHOD) 1564 with pytest.raises(TypeError): 1565 context.set_session_cache_mode(object()) 1566 1567 def test_session_cache_mode(self): 1568 """ 1569 `Context.set_session_cache_mode` specifies how sessions are cached. 1570 The setting can be retrieved via `Context.get_session_cache_mode`. 1571 """ 1572 context = Context(TLSv1_METHOD) 1573 context.set_session_cache_mode(SESS_CACHE_OFF) 1574 off = context.set_session_cache_mode(SESS_CACHE_BOTH) 1575 assert SESS_CACHE_OFF == off 1576 assert SESS_CACHE_BOTH == context.get_session_cache_mode() 1577 1578 @skip_if_py3 1579 def test_session_cache_mode_long(self): 1580 """ 1581 On Python 2 `Context.set_session_cache_mode` accepts values 1582 of type `long` as well as `int`. 1583 """ 1584 context = Context(TLSv1_METHOD) 1585 context.set_session_cache_mode(long(SESS_CACHE_BOTH)) 1586 assert SESS_CACHE_BOTH == context.get_session_cache_mode() 1587 1588 def test_get_cert_store(self): 1589 """ 1590 `Context.get_cert_store` returns a `X509Store` instance. 1591 """ 1592 context = Context(TLSv1_METHOD) 1593 store = context.get_cert_store() 1594 assert isinstance(store, X509Store) 1595 1596 def test_set_tlsext_use_srtp_not_bytes(self): 1597 """ 1598 `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material. 
1599 1600 It raises a TypeError if the list of profiles is not a byte string. 1601 """ 1602 context = Context(TLSv1_METHOD) 1603 with pytest.raises(TypeError): 1604 context.set_tlsext_use_srtp(text_type('SRTP_AES128_CM_SHA1_80')) 1605 1606 def test_set_tlsext_use_srtp_invalid_profile(self): 1607 """ 1608 `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material. 1609 1610 It raises an Error if the call to OpenSSL fails. 1611 """ 1612 context = Context(TLSv1_METHOD) 1613 with pytest.raises(Error): 1614 context.set_tlsext_use_srtp(b'SRTP_BOGUS') 1615 1616 def test_set_tlsext_use_srtp_valid(self): 1617 """ 1618 `Context.set_tlsext_use_srtp' enables negotiating SRTP keying material. 1619 1620 It does not return anything. 1621 """ 1622 context = Context(TLSv1_METHOD) 1623 assert context.set_tlsext_use_srtp(b'SRTP_AES128_CM_SHA1_80') is None 1624 1625 1626class TestServerNameCallback(object): 1627 """ 1628 Tests for `Context.set_tlsext_servername_callback` and its 1629 interaction with `Connection`. 1630 """ 1631 def test_old_callback_forgotten(self): 1632 """ 1633 If `Context.set_tlsext_servername_callback` is used to specify 1634 a new callback, the one it replaces is dereferenced. 1635 """ 1636 def callback(connection): # pragma: no cover 1637 pass 1638 1639 def replacement(connection): # pragma: no cover 1640 pass 1641 1642 context = Context(TLSv1_METHOD) 1643 context.set_tlsext_servername_callback(callback) 1644 1645 tracker = ref(callback) 1646 del callback 1647 1648 context.set_tlsext_servername_callback(replacement) 1649 1650 # One run of the garbage collector happens to work on CPython. PyPy 1651 # doesn't collect the underlying object until a second run for whatever 1652 # reason. That's fine, it still demonstrates our code has properly 1653 # dropped the reference. 
1654 collect() 1655 collect() 1656 1657 callback = tracker() 1658 if callback is not None: 1659 referrers = get_referrers(callback) 1660 if len(referrers) > 1: # pragma: nocover 1661 pytest.fail("Some references remain: %r" % (referrers,)) 1662 1663 def test_no_servername(self): 1664 """ 1665 When a client specifies no server name, the callback passed to 1666 `Context.set_tlsext_servername_callback` is invoked and the 1667 result of `Connection.get_servername` is `None`. 1668 """ 1669 args = [] 1670 1671 def servername(conn): 1672 args.append((conn, conn.get_servername())) 1673 context = Context(TLSv1_METHOD) 1674 context.set_tlsext_servername_callback(servername) 1675 1676 # Lose our reference to it. The Context is responsible for keeping it 1677 # alive now. 1678 del servername 1679 collect() 1680 1681 # Necessary to actually accept the connection 1682 context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 1683 context.use_certificate( 1684 load_certificate(FILETYPE_PEM, server_cert_pem)) 1685 1686 # Do a little connection to trigger the logic 1687 server = Connection(context, None) 1688 server.set_accept_state() 1689 1690 client = Connection(Context(TLSv1_METHOD), None) 1691 client.set_connect_state() 1692 1693 interact_in_memory(server, client) 1694 1695 assert args == [(server, None)] 1696 1697 def test_servername(self): 1698 """ 1699 When a client specifies a server name in its hello message, the 1700 callback passed to `Contexts.set_tlsext_servername_callback` is 1701 invoked and the result of `Connection.get_servername` is that 1702 server name. 
1703 """ 1704 args = [] 1705 1706 def servername(conn): 1707 args.append((conn, conn.get_servername())) 1708 context = Context(TLSv1_METHOD) 1709 context.set_tlsext_servername_callback(servername) 1710 1711 # Necessary to actually accept the connection 1712 context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 1713 context.use_certificate( 1714 load_certificate(FILETYPE_PEM, server_cert_pem)) 1715 1716 # Do a little connection to trigger the logic 1717 server = Connection(context, None) 1718 server.set_accept_state() 1719 1720 client = Connection(Context(TLSv1_METHOD), None) 1721 client.set_connect_state() 1722 client.set_tlsext_host_name(b"foo1.example.com") 1723 1724 interact_in_memory(server, client) 1725 1726 assert args == [(server, b"foo1.example.com")] 1727 1728 1729class TestNextProtoNegotiation(object): 1730 """ 1731 Test for Next Protocol Negotiation in PyOpenSSL. 1732 """ 1733 def test_npn_success(self): 1734 """ 1735 Tests that clients and servers that agree on the negotiated next 1736 protocol can correct establish a connection, and that the agreed 1737 protocol is reported by the connections. 
1738 """ 1739 advertise_args = [] 1740 select_args = [] 1741 1742 def advertise(conn): 1743 advertise_args.append((conn,)) 1744 return [b'http/1.1', b'spdy/2'] 1745 1746 def select(conn, options): 1747 select_args.append((conn, options)) 1748 return b'spdy/2' 1749 1750 server_context = Context(TLSv1_METHOD) 1751 server_context.set_npn_advertise_callback(advertise) 1752 1753 client_context = Context(TLSv1_METHOD) 1754 client_context.set_npn_select_callback(select) 1755 1756 # Necessary to actually accept the connection 1757 server_context.use_privatekey( 1758 load_privatekey(FILETYPE_PEM, server_key_pem)) 1759 server_context.use_certificate( 1760 load_certificate(FILETYPE_PEM, server_cert_pem)) 1761 1762 # Do a little connection to trigger the logic 1763 server = Connection(server_context, None) 1764 server.set_accept_state() 1765 1766 client = Connection(client_context, None) 1767 client.set_connect_state() 1768 1769 interact_in_memory(server, client) 1770 1771 assert advertise_args == [(server,)] 1772 assert select_args == [(client, [b'http/1.1', b'spdy/2'])] 1773 1774 assert server.get_next_proto_negotiated() == b'spdy/2' 1775 assert client.get_next_proto_negotiated() == b'spdy/2' 1776 1777 def test_npn_client_fail(self): 1778 """ 1779 Tests that when clients and servers cannot agree on what protocol 1780 to use next that the TLS connection does not get established. 
1781 """ 1782 advertise_args = [] 1783 select_args = [] 1784 1785 def advertise(conn): 1786 advertise_args.append((conn,)) 1787 return [b'http/1.1', b'spdy/2'] 1788 1789 def select(conn, options): 1790 select_args.append((conn, options)) 1791 return b'' 1792 1793 server_context = Context(TLSv1_METHOD) 1794 server_context.set_npn_advertise_callback(advertise) 1795 1796 client_context = Context(TLSv1_METHOD) 1797 client_context.set_npn_select_callback(select) 1798 1799 # Necessary to actually accept the connection 1800 server_context.use_privatekey( 1801 load_privatekey(FILETYPE_PEM, server_key_pem)) 1802 server_context.use_certificate( 1803 load_certificate(FILETYPE_PEM, server_cert_pem)) 1804 1805 # Do a little connection to trigger the logic 1806 server = Connection(server_context, None) 1807 server.set_accept_state() 1808 1809 client = Connection(client_context, None) 1810 client.set_connect_state() 1811 1812 # If the client doesn't return anything, the connection will fail. 1813 with pytest.raises(Error): 1814 interact_in_memory(server, client) 1815 1816 assert advertise_args == [(server,)] 1817 assert select_args == [(client, [b'http/1.1', b'spdy/2'])] 1818 1819 def test_npn_select_error(self): 1820 """ 1821 Test that we can handle exceptions in the select callback. If 1822 select fails it should be fatal to the connection. 
1823 """ 1824 advertise_args = [] 1825 1826 def advertise(conn): 1827 advertise_args.append((conn,)) 1828 return [b'http/1.1', b'spdy/2'] 1829 1830 def select(conn, options): 1831 raise TypeError 1832 1833 server_context = Context(TLSv1_METHOD) 1834 server_context.set_npn_advertise_callback(advertise) 1835 1836 client_context = Context(TLSv1_METHOD) 1837 client_context.set_npn_select_callback(select) 1838 1839 # Necessary to actually accept the connection 1840 server_context.use_privatekey( 1841 load_privatekey(FILETYPE_PEM, server_key_pem)) 1842 server_context.use_certificate( 1843 load_certificate(FILETYPE_PEM, server_cert_pem)) 1844 1845 # Do a little connection to trigger the logic 1846 server = Connection(server_context, None) 1847 server.set_accept_state() 1848 1849 client = Connection(client_context, None) 1850 client.set_connect_state() 1851 1852 # If the callback throws an exception it should be raised here. 1853 with pytest.raises(TypeError): 1854 interact_in_memory(server, client) 1855 assert advertise_args == [(server,), ] 1856 1857 def test_npn_advertise_error(self): 1858 """ 1859 Test that we can handle exceptions in the advertise callback. If 1860 advertise fails no NPN is advertised to the client. 1861 """ 1862 select_args = [] 1863 1864 def advertise(conn): 1865 raise TypeError 1866 1867 def select(conn, options): # pragma: nocover 1868 """ 1869 Assert later that no args are actually appended. 
1870 """ 1871 select_args.append((conn, options)) 1872 return b'' 1873 1874 server_context = Context(TLSv1_METHOD) 1875 server_context.set_npn_advertise_callback(advertise) 1876 1877 client_context = Context(TLSv1_METHOD) 1878 client_context.set_npn_select_callback(select) 1879 1880 # Necessary to actually accept the connection 1881 server_context.use_privatekey( 1882 load_privatekey(FILETYPE_PEM, server_key_pem)) 1883 server_context.use_certificate( 1884 load_certificate(FILETYPE_PEM, server_cert_pem)) 1885 1886 # Do a little connection to trigger the logic 1887 server = Connection(server_context, None) 1888 server.set_accept_state() 1889 1890 client = Connection(client_context, None) 1891 client.set_connect_state() 1892 1893 # If the client doesn't return anything, the connection will fail. 1894 with pytest.raises(TypeError): 1895 interact_in_memory(server, client) 1896 assert select_args == [] 1897 1898 1899class TestApplicationLayerProtoNegotiation(object): 1900 """ 1901 Tests for ALPN in PyOpenSSL. 1902 """ 1903 # Skip tests on versions that don't support ALPN. 1904 if _lib.Cryptography_HAS_ALPN: 1905 1906 def test_alpn_success(self): 1907 """ 1908 Clients and servers that agree on the negotiated ALPN protocol can 1909 correct establish a connection, and the agreed protocol is reported 1910 by the connections. 
1911 """ 1912 select_args = [] 1913 1914 def select(conn, options): 1915 select_args.append((conn, options)) 1916 return b'spdy/2' 1917 1918 client_context = Context(TLSv1_METHOD) 1919 client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) 1920 1921 server_context = Context(TLSv1_METHOD) 1922 server_context.set_alpn_select_callback(select) 1923 1924 # Necessary to actually accept the connection 1925 server_context.use_privatekey( 1926 load_privatekey(FILETYPE_PEM, server_key_pem)) 1927 server_context.use_certificate( 1928 load_certificate(FILETYPE_PEM, server_cert_pem)) 1929 1930 # Do a little connection to trigger the logic 1931 server = Connection(server_context, None) 1932 server.set_accept_state() 1933 1934 client = Connection(client_context, None) 1935 client.set_connect_state() 1936 1937 interact_in_memory(server, client) 1938 1939 assert select_args == [(server, [b'http/1.1', b'spdy/2'])] 1940 1941 assert server.get_alpn_proto_negotiated() == b'spdy/2' 1942 assert client.get_alpn_proto_negotiated() == b'spdy/2' 1943 1944 def test_alpn_set_on_connection(self): 1945 """ 1946 The same as test_alpn_success, but setting the ALPN protocols on 1947 the connection rather than the context. 1948 """ 1949 select_args = [] 1950 1951 def select(conn, options): 1952 select_args.append((conn, options)) 1953 return b'spdy/2' 1954 1955 # Setup the client context but don't set any ALPN protocols. 1956 client_context = Context(TLSv1_METHOD) 1957 1958 server_context = Context(TLSv1_METHOD) 1959 server_context.set_alpn_select_callback(select) 1960 1961 # Necessary to actually accept the connection 1962 server_context.use_privatekey( 1963 load_privatekey(FILETYPE_PEM, server_key_pem)) 1964 server_context.use_certificate( 1965 load_certificate(FILETYPE_PEM, server_cert_pem)) 1966 1967 # Do a little connection to trigger the logic 1968 server = Connection(server_context, None) 1969 server.set_accept_state() 1970 1971 # Set the ALPN protocols on the client connection. 
1972 client = Connection(client_context, None) 1973 client.set_alpn_protos([b'http/1.1', b'spdy/2']) 1974 client.set_connect_state() 1975 1976 interact_in_memory(server, client) 1977 1978 assert select_args == [(server, [b'http/1.1', b'spdy/2'])] 1979 1980 assert server.get_alpn_proto_negotiated() == b'spdy/2' 1981 assert client.get_alpn_proto_negotiated() == b'spdy/2' 1982 1983 def test_alpn_server_fail(self): 1984 """ 1985 When clients and servers cannot agree on what protocol to use next, 1986 the TLS connection does not get established. 1987 """ 1988 select_args = [] 1989 1990 def select(conn, options): 1991 select_args.append((conn, options)) 1992 return b'' 1993 1994 client_context = Context(TLSv1_METHOD) 1995 client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) 1996 1997 server_context = Context(TLSv1_METHOD) 1998 server_context.set_alpn_select_callback(select) 1999 2000 # Necessary to actually accept the connection 2001 server_context.use_privatekey( 2002 load_privatekey(FILETYPE_PEM, server_key_pem)) 2003 server_context.use_certificate( 2004 load_certificate(FILETYPE_PEM, server_cert_pem)) 2005 2006 # Do a little connection to trigger the logic 2007 server = Connection(server_context, None) 2008 server.set_accept_state() 2009 2010 client = Connection(client_context, None) 2011 client.set_connect_state() 2012 2013 # The server's select callback returns nothing (b''), so the handshake must fail. 2014 with pytest.raises(Error): 2015 interact_in_memory(server, client) 2016 2017 assert select_args == [(server, [b'http/1.1', b'spdy/2'])] 2018 2019 def test_alpn_no_server(self): 2020 """ 2021 When the client offers ALPN protocols but the server has no ALPN select 2022 callback configured, the handshake still succeeds and no protocol is negotiated.
2023 """ 2024 client_context = Context(TLSv1_METHOD) 2025 client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) 2026 2027 server_context = Context(TLSv1_METHOD) 2028 2029 # Necessary to actually accept the connection 2030 server_context.use_privatekey( 2031 load_privatekey(FILETYPE_PEM, server_key_pem)) 2032 server_context.use_certificate( 2033 load_certificate(FILETYPE_PEM, server_cert_pem)) 2034 2035 # Do a little connection to trigger the logic 2036 server = Connection(server_context, None) 2037 server.set_accept_state() 2038 2039 client = Connection(client_context, None) 2040 client.set_connect_state() 2041 2042 # Run the in-memory handshake; it is expected to complete without ALPN. 2043 interact_in_memory(server, client) 2044 2045 assert client.get_alpn_proto_negotiated() == b'' 2046 2047 def test_alpn_callback_exception(self): 2048 """ 2049 An exception raised by the ALPN select callback propagates out of the in-memory handshake to the caller. 2050 """ 2051 select_args = [] 2052 2053 def select(conn, options): 2054 select_args.append((conn, options)) 2055 raise TypeError() 2056 2057 client_context = Context(TLSv1_METHOD) 2058 client_context.set_alpn_protos([b'http/1.1', b'spdy/2']) 2059 2060 server_context = Context(TLSv1_METHOD) 2061 server_context.set_alpn_select_callback(select) 2062 2063 # Necessary to actually accept the connection 2064 server_context.use_privatekey( 2065 load_privatekey(FILETYPE_PEM, server_key_pem)) 2066 server_context.use_certificate( 2067 load_certificate(FILETYPE_PEM, server_cert_pem)) 2068 2069 # Do a little connection to trigger the logic 2070 server = Connection(server_context, None) 2071 server.set_accept_state() 2072 2073 client = Connection(client_context, None) 2074 client.set_connect_state() 2075 2076 with pytest.raises(TypeError): 2077 interact_in_memory(server, client) 2078 assert select_args == [(server, [b'http/1.1', b'spdy/2'])] 2079 2080 else: 2081 # No ALPN. 2082 def test_alpn_not_implemented(self): 2083 """ 2084 If the linked OpenSSL lacks ALPN support, the ALPN setters on `Context` and `Connection` raise NotImplementedError.
2085 """ 2086 # Test the context methods first. 2087 context = Context(TLSv1_METHOD) 2088 with pytest.raises(NotImplementedError): 2089 context.set_alpn_protos(None) 2090 with pytest.raises(NotImplementedError): 2091 context.set_alpn_select_callback(None) 2092 2093 # Now test a connection. 2094 conn = Connection(context) 2095 with pytest.raises(NotImplementedError): 2096 conn.set_alpn_protos(None) 2097 2098 2099class TestSession(object): 2100 """ 2101 Unit tests for :py:obj:`OpenSSL.SSL.Session`. 2102 """ 2103 def test_construction(self): 2104 """ 2105 :py:class:`Session` can be constructed with no arguments, creating 2106 a new instance of that type. 2107 """ 2108 new_session = Session() 2109 assert isinstance(new_session, Session) 2110 2111 2112class TestConnection(object): 2113 """ 2114 Unit tests for `OpenSSL.SSL.Connection`. 2115 """ 2116 # XXX get_peer_certificate -> None 2117 # XXX sock_shutdown 2118 # XXX master_key -> TypeError 2119 # XXX server_random -> TypeError 2120 # XXX connect -> TypeError 2121 # XXX connect_ex -> TypeError 2122 # XXX set_connect_state -> TypeError 2123 # XXX set_accept_state -> TypeError 2124 # XXX do_handshake -> TypeError 2125 # XXX bio_read -> TypeError 2126 # XXX recv -> TypeError 2127 # XXX send -> TypeError 2128 # XXX bio_write -> TypeError 2129 2130 def test_type(self): 2131 """ 2132 `Connection` and `ConnectionType` refer to the same type object and 2133 can be used to create instances of that type. 2134 """ 2135 assert Connection is ConnectionType 2136 ctx = Context(TLSv1_METHOD) 2137 assert is_consistent_type(Connection, 'Connection', ctx, None) 2138 2139 @pytest.mark.parametrize('bad_context', [object(), 'context', None, 1]) 2140 def test_wrong_args(self, bad_context): 2141 """ 2142 `Connection.__init__` raises `TypeError` if called with a non-`Context` 2143 instance argument.
2144 """ 2145 with pytest.raises(TypeError): 2146 Connection(bad_context) 2147 2148 def test_get_context(self): 2149 """ 2150 `Connection.get_context` returns the `Context` instance used to 2151 construct the `Connection` instance. 2152 """ 2153 context = Context(TLSv1_METHOD) 2154 connection = Connection(context, None) 2155 assert connection.get_context() is context 2156 2157 def test_set_context_wrong_args(self): 2158 """ 2159 `Connection.set_context` raises `TypeError` if called with a 2160 non-`Context` instance argument. 2161 """ 2162 ctx = Context(TLSv1_METHOD) 2163 connection = Connection(ctx, None) 2164 with pytest.raises(TypeError): 2165 connection.set_context(object()) 2166 with pytest.raises(TypeError): 2167 connection.set_context("hello") 2168 with pytest.raises(TypeError): 2169 connection.set_context(1) 2170 assert ctx is connection.get_context() 2171 2172 def test_set_context(self): 2173 """ 2174 `Connection.set_context` specifies a new `Context` instance to be 2175 used for the connection. 2176 """ 2177 original = Context(SSLv23_METHOD) 2178 replacement = Context(TLSv1_METHOD) 2179 connection = Connection(original, None) 2180 connection.set_context(replacement) 2181 assert replacement is connection.get_context() 2182 # Lose our references to the contexts, just in case the Connection 2183 # isn't properly managing its own contributions to their reference 2184 # counts. 2185 del original, replacement 2186 collect() 2187 2188 def test_set_tlsext_host_name_wrong_args(self): 2189 """ 2190 If `Connection.set_tlsext_host_name` is called with a non-byte string 2191 argument or a byte string with an embedded NUL, `TypeError` is raised. 2192 """ 2193 conn = Connection(Context(TLSv1_METHOD), None) 2194 with pytest.raises(TypeError): 2195 conn.set_tlsext_host_name(object()) 2196 with pytest.raises(TypeError): 2197 conn.set_tlsext_host_name(b"with\0null") 2198 2199 if PY3: 2200 # On Python 3.x, a text (unicode) host name must be rejected with TypeError rather than implicitly converted.
2201 with pytest.raises(TypeError): 2202 conn.set_tlsext_host_name(b"example.com".decode("ascii")) 2203 2204 def test_pending(self): 2205 """ 2206 `Connection.pending` returns the number of bytes available for 2207 immediate read; a fresh, unconnected `Connection` reports 0. 2208 """ 2209 connection = Connection(Context(TLSv1_METHOD), None) 2210 assert connection.pending() == 0 2211 2212 def test_peek(self): 2213 """ 2214 `Connection.recv` peeks into the connection if `socket.MSG_PEEK` is 2215 passed. 2216 """ 2217 server, client = loopback() 2218 server.send(b'xy') 2219 assert client.recv(2, MSG_PEEK) == b'xy' 2220 assert client.recv(2, MSG_PEEK) == b'xy' 2221 assert client.recv(2) == b'xy' 2222 2223 def test_connect_wrong_args(self): 2224 """ 2225 `Connection.connect` raises `TypeError` if called with a non-address 2226 argument. 2227 """ 2228 connection = Connection(Context(TLSv1_METHOD), socket()) 2229 with pytest.raises(TypeError): 2230 connection.connect(None) 2231 2232 def test_connect_refused(self): 2233 """ 2234 `Connection.connect` raises `socket.error` if the underlying socket 2235 connect method raises it. NOTE(review): if `connect` unexpectedly succeeds, `exc` is never bound and the final assertion fails with `NameError` instead of a clear test failure. 2236 """ 2237 client = socket() 2238 context = Context(TLSv1_METHOD) 2239 clientSSL = Connection(context, client) 2240 # pytest.raises here doesn't work because of a bug in py.test on Python 2241 # 2.6: https://github.com/pytest-dev/pytest/issues/988 2242 try: 2243 clientSSL.connect(("127.0.0.1", 1)) 2244 except error as e: 2245 exc = e 2246 assert exc.args[0] == ECONNREFUSED 2247 2248 def test_connect(self): 2249 """ 2250 `Connection.connect` establishes a connection to the specified address. 2251 """ 2252 port = socket() 2253 port.bind(('', 0)) 2254 port.listen(3) 2255 2256 clientSSL = Connection(Context(TLSv1_METHOD), socket()) 2257 clientSSL.connect(('127.0.0.1', port.getsockname()[1])) 2258 # XXX An assertion? Or something? As written, this only checks that connect() does not raise, and neither socket is closed afterwards.
2259 2260 @pytest.mark.skipif( 2261 platform == "darwin", 2262 reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4" 2263 ) 2264 def test_connect_ex(self): 2265 """ 2266 If the connection attempt cannot complete immediately (here: a non-blocking socket), `Connection.connect_ex` returns the 2267 errno instead of raising an exception. 2268 """ 2269 port = socket() 2270 port.bind(('', 0)) 2271 port.listen(3) 2272 2273 clientSSL = Connection(Context(TLSv1_METHOD), socket()) 2274 clientSSL.setblocking(False) 2275 result = clientSSL.connect_ex(port.getsockname()) 2276 expected = (EINPROGRESS, EWOULDBLOCK) 2277 assert result in expected 2278 2279 def test_accept(self): 2280 """ 2281 `Connection.accept` accepts a pending connection attempt and returns a 2282 tuple of a new `Connection` (the accepted client) and the address the 2283 connection originated from. 2284 """ 2285 ctx = Context(TLSv1_METHOD) 2286 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 2287 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 2288 port = socket() 2289 portSSL = Connection(ctx, port) 2290 portSSL.bind(('', 0)) 2291 portSSL.listen(3) 2292 2293 clientSSL = Connection(Context(TLSv1_METHOD), socket()) 2294 2295 # Calling portSSL.getsockname() here to get the server IP address 2296 # sounds great, but frequently fails on Windows. 2297 clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1])) 2298 2299 serverSSL, address = portSSL.accept() 2300 2301 assert isinstance(serverSSL, Connection) 2302 assert serverSSL.get_context() is ctx 2303 assert address == clientSSL.getsockname() 2304 2305 def test_shutdown_wrong_args(self): 2306 """ 2307 `Connection.set_shutdown` raises `TypeError` if called with arguments 2308 other than integers. 2309 """ 2310 connection = Connection(Context(TLSv1_METHOD), None) 2311 with pytest.raises(TypeError): 2312 connection.set_shutdown(None) 2313 2314 def test_shutdown(self): 2315 """ 2316 `Connection.shutdown` performs an SSL-level connection shutdown.
2317 """ 2318 server, client = loopback() 2319 assert not server.shutdown() 2320 assert server.get_shutdown() == SENT_SHUTDOWN 2321 with pytest.raises(ZeroReturnError): 2322 client.recv(1024) 2323 assert client.get_shutdown() == RECEIVED_SHUTDOWN 2324 client.shutdown() 2325 assert client.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN) 2326 with pytest.raises(ZeroReturnError): 2327 server.recv(1024) 2328 assert server.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN) 2329 2330 def test_shutdown_closed(self): 2331 """ 2332 If the underlying socket has been shut down in both directions, `Connection.shutdown` propagates 2333 the write error from the low-level write call (`EPIPE`, or `ESHUTDOWN` on Windows). 2334 """ 2335 server, client = loopback() 2336 server.sock_shutdown(2) 2337 with pytest.raises(SysCallError) as exc: 2338 server.shutdown() 2339 if platform == "win32": 2340 assert exc.value.args[0] == ESHUTDOWN 2341 else: 2342 assert exc.value.args[0] == EPIPE 2343 2344 def test_shutdown_truncated(self): 2345 """ 2346 If the underlying connection is truncated, `Connection.shutdown` 2347 raises an `Error`. 2348 """ 2349 server_ctx = Context(TLSv1_METHOD) 2350 client_ctx = Context(TLSv1_METHOD) 2351 server_ctx.use_privatekey( 2352 load_privatekey(FILETYPE_PEM, server_key_pem)) 2353 server_ctx.use_certificate( 2354 load_certificate(FILETYPE_PEM, server_cert_pem)) 2355 server = Connection(server_ctx, None) 2356 client = Connection(client_ctx, None) 2357 handshake_in_memory(client, server) 2358 assert not server.shutdown() 2359 with pytest.raises(WantReadError): 2360 server.shutdown() 2361 server.bio_shutdown() 2362 with pytest.raises(Error): 2363 server.shutdown() 2364 2365 def test_set_shutdown(self): 2366 """ 2367 `Connection.set_shutdown` sets the state of the SSL connection 2368 shutdown process.
2369 """ 2370 connection = Connection(Context(TLSv1_METHOD), socket()) 2371 connection.set_shutdown(RECEIVED_SHUTDOWN) 2372 assert connection.get_shutdown() == RECEIVED_SHUTDOWN 2373 2374 @skip_if_py3 2375 def test_set_shutdown_long(self): 2376 """ 2377 On Python 2 `Connection.set_shutdown` accepts an argument 2378 of type `long` as well as `int`. 2379 """ 2380 connection = Connection(Context(TLSv1_METHOD), socket()) 2381 connection.set_shutdown(long(RECEIVED_SHUTDOWN)) 2382 assert connection.get_shutdown() == RECEIVED_SHUTDOWN 2383 2384 def test_state_string(self): 2385 """ 2386 `Connection.get_state_string` verbosely describes the current state of 2387 the `Connection` as a byte string; here the "before ..." states are checked. 2388 """ 2389 server, client = socket_pair() 2390 server = loopback_server_factory(server) 2391 client = loopback_client_factory(client) 2392 2393 assert server.get_state_string() in [ 2394 b"before/accept initialization", b"before SSL initialization" 2395 ] 2396 assert client.get_state_string() in [ 2397 b"before/connect initialization", b"before SSL initialization" 2398 ] 2399 2400 def test_app_data(self): 2401 """ 2402 Any object can be set as app data by passing it to 2403 `Connection.set_app_data` and later retrieved with 2404 `Connection.get_app_data`. 2405 """ 2406 conn = Connection(Context(TLSv1_METHOD), None) 2407 assert None is conn.get_app_data() 2408 app_data = object() 2409 conn.set_app_data(app_data) 2410 assert conn.get_app_data() is app_data 2411 2412 def test_makefile(self): 2413 """ 2414 `Connection.makefile` is not implemented and calling that 2415 method raises `NotImplementedError`. 2416 """ 2417 conn = Connection(Context(TLSv1_METHOD), None) 2418 with pytest.raises(NotImplementedError): 2419 conn.makefile() 2420 2421 def test_get_certificate(self): 2422 """ 2423 `Connection.get_certificate` returns the local certificate.
2424 """ 2425 chain = _create_certificate_chain() 2426 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 2427 2428 context = Context(TLSv1_METHOD) 2429 context.use_certificate(scert) 2430 client = Connection(context, None) 2431 cert = client.get_certificate() 2432 assert cert is not None 2433 assert "Server Certificate" == cert.get_subject().CN 2434 2435 def test_get_certificate_none(self): 2436 """ 2437 `Connection.get_certificate` returns the local certificate, if any. 2438 2439 If there is no certificate, it returns None. 2440 """ 2441 context = Context(TLSv1_METHOD) 2442 client = Connection(context, None) 2443 cert = client.get_certificate() 2444 assert cert is None 2445 2446 def test_get_peer_cert_chain(self): 2447 """ 2448 `Connection.get_peer_cert_chain` returns a list of certificates 2449 which the connected server sent for certificate verification: the leaf first, then the extra chain certificates in the order they were added. 2450 """ 2451 chain = _create_certificate_chain() 2452 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 2453 2454 serverContext = Context(TLSv1_METHOD) 2455 serverContext.use_privatekey(skey) 2456 serverContext.use_certificate(scert) 2457 serverContext.add_extra_chain_cert(icert) 2458 serverContext.add_extra_chain_cert(cacert) 2459 server = Connection(serverContext, None) 2460 server.set_accept_state() 2461 2462 # Create the client 2463 clientContext = Context(TLSv1_METHOD) 2464 clientContext.set_verify(VERIFY_NONE, verify_cb) 2465 client = Connection(clientContext, None) 2466 client.set_connect_state() 2467 2468 interact_in_memory(client, server) 2469 2470 chain = client.get_peer_cert_chain() 2471 assert len(chain) == 3 2472 assert "Server Certificate" == chain[0].get_subject().CN 2473 assert "Intermediate Certificate" == chain[1].get_subject().CN 2474 assert "Authority Certificate" == chain[2].get_subject().CN 2475 2476 def test_get_peer_cert_chain_none(self): 2477 """ 2478 `Connection.get_peer_cert_chain` returns `None` if the peer sends 2479 no certificate chain.
2480 """ 2481 ctx = Context(TLSv1_METHOD) 2482 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 2483 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 2484 server = Connection(ctx, None) 2485 server.set_accept_state() 2486 client = Connection(Context(TLSv1_METHOD), None) 2487 client.set_connect_state() 2488 interact_in_memory(client, server) 2489 assert None is server.get_peer_cert_chain() 2490 2491 def test_get_session_unconnected(self): 2492 """ 2493 `Connection.get_session` returns `None` when used with an object 2494 which has not been connected. 2495 """ 2496 ctx = Context(TLSv1_METHOD) 2497 server = Connection(ctx, None) 2498 session = server.get_session() 2499 assert None is session 2500 2501 def test_server_get_session(self): 2502 """ 2503 On the server side of a connection, `Connection.get_session` returns a 2504 `Session` instance representing the SSL session for that connection. 2505 """ 2506 server, client = loopback() 2507 session = server.get_session() 2508 assert isinstance(session, Session) 2509 2510 def test_client_get_session(self): 2511 """ 2512 On the client side of a connection, `Connection.get_session` 2513 returns a `Session` instance representing the SSL session for 2514 that connection. 2515 """ 2516 server, client = loopback() 2517 session = client.get_session() 2518 assert isinstance(session, Session) 2519 2520 def test_set_session_wrong_args(self): 2521 """ 2522 `Connection.set_session` raises `TypeError` if called with an object 2523 that is not an instance of `Session`.
2524 """ 2525 ctx = Context(TLSv1_METHOD) 2526 connection = Connection(ctx, None) 2527 with pytest.raises(TypeError): 2528 connection.set_session(123) 2529 with pytest.raises(TypeError): 2530 connection.set_session("hello") 2531 with pytest.raises(TypeError): 2532 connection.set_session(object()) 2533 2534 def test_client_set_session(self): 2535 """ 2536 `Connection.set_session`, when used prior to a connection being 2537 established, accepts a `Session` instance and causes an attempt to 2538 re-use the session it represents when the SSL handshake is performed. 2539 """ 2540 key = load_privatekey(FILETYPE_PEM, server_key_pem) 2541 cert = load_certificate(FILETYPE_PEM, server_cert_pem) 2542 ctx = Context(TLSv1_2_METHOD) 2543 ctx.use_privatekey(key) 2544 ctx.use_certificate(cert) 2545 ctx.set_session_id("unity-test") 2546 2547 def makeServer(socket): 2548 server = Connection(ctx, socket) 2549 server.set_accept_state() 2550 return server 2551 2552 originalServer, originalClient = loopback( 2553 server_factory=makeServer) 2554 originalSession = originalClient.get_session() 2555 2556 def makeClient(socket): 2557 client = loopback_client_factory(socket) 2558 client.set_session(originalSession) 2559 return client 2560 resumedServer, resumedClient = loopback( 2561 server_factory=makeServer, 2562 client_factory=makeClient) 2563 2564 # This is a proxy: in general, we have no access to any unique 2565 # identifier for the session (new enough versions of OpenSSL expose 2566 # a hash which could be usable, but "new enough" is very, very new). 2567 # Instead, exploit the fact that the master key is re-used if the 2568 # session is re-used. As long as the master key for the two 2569 # connections is the same, the session was re-used!
2570 assert originalServer.master_key() == resumedServer.master_key() 2571 2572 def test_set_session_wrong_method(self): 2573 """ 2574 If `Connection.set_session` is passed a `Session` instance associated 2575 with a context using a different SSL method than the `Connection` 2576 is using, an `OpenSSL.SSL.Error` is raised. 2577 """ 2578 # Make this work on both OpenSSL 1.0.0, which doesn't support TLSv1.2 2579 # and also on OpenSSL 1.1.0 which doesn't support SSLv3. (SSL_ST_INIT 2580 # is a way to check for 1.1.0) 2581 if SSL_ST_INIT is None: 2582 v1 = TLSv1_2_METHOD 2583 v2 = TLSv1_METHOD 2584 elif hasattr(_lib, "SSLv3_method"): 2585 v1 = TLSv1_METHOD 2586 v2 = SSLv3_METHOD 2587 else: 2588 pytest.skip("Test requires either OpenSSL 1.1.0 or SSLv3") 2589 2590 key = load_privatekey(FILETYPE_PEM, server_key_pem) 2591 cert = load_certificate(FILETYPE_PEM, server_cert_pem) 2592 ctx = Context(v1) 2593 ctx.use_privatekey(key) 2594 ctx.use_certificate(cert) 2595 ctx.set_session_id("unity-test") 2596 2597 def makeServer(socket): 2598 server = Connection(ctx, socket) 2599 server.set_accept_state() 2600 return server 2601 2602 def makeOriginalClient(socket): 2603 client = Connection(Context(v1), socket) 2604 client.set_connect_state() 2605 return client 2606 2607 originalServer, originalClient = loopback( 2608 server_factory=makeServer, client_factory=makeOriginalClient) 2609 originalSession = originalClient.get_session() 2610 2611 def makeClient(socket): 2612 # Intentionally use a different, incompatible method here. 2613 client = Connection(Context(v2), socket) 2614 client.set_connect_state() 2615 client.set_session(originalSession) 2616 return client 2617 2618 with pytest.raises(Error): 2619 loopback(client_factory=makeClient, server_factory=makeServer) 2620 2621 def test_wantWriteError(self): 2622 """ 2623 `Connection` methods which generate output raise 2624 `OpenSSL.SSL.WantWriteError` if writing to the connection's BIO 2625 fails, indicating a should-write state.
2626 """ 2627 client_socket, server_socket = socket_pair() 2628 # Fill up the client's send buffer so Connection won't be able to write 2629 # anything. Only write a single byte at a time so we can be sure we 2630 # completely fill the buffer. Even though the socket API is allowed to 2631 # signal a short write via its return value it seems this doesn't 2632 # always happen on all platforms (FreeBSD and OS X in particular) for the 2633 # very last bit of available buffer space. 2634 msg = b"x" 2635 for i in range(1024 * 1024 * 64): 2636 try: 2637 client_socket.send(msg) 2638 except error as e: 2639 if e.errno == EWOULDBLOCK: 2640 break 2641 raise 2642 else: 2643 pytest.fail( 2644 "Failed to fill socket buffer, cannot test BIO want write") 2645 2646 ctx = Context(TLSv1_METHOD) 2647 conn = Connection(ctx, client_socket) 2648 # Clients speak first, so make it an SSL client 2649 conn.set_connect_state() 2650 with pytest.raises(WantWriteError): 2651 conn.do_handshake() 2652 2653 # XXX want_read 2654 2655 def test_get_finished_before_connect(self): 2656 """ 2657 `Connection.get_finished` returns `None` before the TLS handshake 2658 is completed. 2659 """ 2660 ctx = Context(TLSv1_METHOD) 2661 connection = Connection(ctx, None) 2662 assert connection.get_finished() is None 2663 2664 def test_get_peer_finished_before_connect(self): 2665 """ 2666 `Connection.get_peer_finished` returns `None` before the TLS handshake 2667 is completed. 2668 """ 2669 ctx = Context(TLSv1_METHOD) 2670 connection = Connection(ctx, None) 2671 assert connection.get_peer_finished() is None 2672 2673 def test_get_finished(self): 2674 """ 2675 `Connection.get_finished` method returns the TLS Finished message sent 2676 by the client or server. Finished messages are sent during the 2677 TLS handshake.
2678 """ 2679 server, client = loopback() 2680 2681 assert server.get_finished() is not None 2682 assert len(server.get_finished()) > 0 2683 2684 def test_get_peer_finished(self): 2685 """ 2686 `Connection.get_peer_finished` method returns the TLS Finished 2687 message received from the client or server. Finished messages are sent 2688 during the TLS handshake. 2689 """ 2690 server, client = loopback() 2691 2692 assert server.get_peer_finished() is not None 2693 assert len(server.get_peer_finished()) > 0 2694 2695 def test_tls_finished_message_symmetry(self): 2696 """ 2697 The TLS Finished message sent by the server must be the TLS Finished 2698 message received by the client. 2699 2700 The TLS Finished message sent by the client must be the TLS Finished 2701 message received by the server. 2702 """ 2703 server, client = loopback() 2704 2705 assert server.get_finished() == client.get_peer_finished() 2706 assert client.get_finished() == server.get_peer_finished() 2707 2708 def test_get_cipher_name_before_connect(self): 2709 """ 2710 `Connection.get_cipher_name` returns `None` if no connection 2711 has been established. 2712 """ 2713 ctx = Context(TLSv1_METHOD) 2714 conn = Connection(ctx, None) 2715 assert conn.get_cipher_name() is None 2716 2717 def test_get_cipher_name(self): 2718 """ 2719 `Connection.get_cipher_name` returns a `unicode` string giving the 2720 name of the currently used cipher. 2721 """ 2722 server, client = loopback() 2723 server_cipher_name, client_cipher_name = \ 2724 server.get_cipher_name(), client.get_cipher_name() 2725 2726 assert isinstance(server_cipher_name, text_type) 2727 assert isinstance(client_cipher_name, text_type) 2728 2729 assert server_cipher_name == client_cipher_name 2730 2731 def test_get_cipher_version_before_connect(self): 2732 """ 2733 `Connection.get_cipher_version` returns `None` if no connection 2734 has been established.
2735 """ 2736 ctx = Context(TLSv1_METHOD) 2737 conn = Connection(ctx, None) 2738 assert conn.get_cipher_version() is None 2739 2740 def test_get_cipher_version(self): 2741 """ 2742 `Connection.get_cipher_version` returns a `unicode` string giving 2743 the protocol name of the currently used cipher. 2744 """ 2745 server, client = loopback() 2746 server_cipher_version, client_cipher_version = \ 2747 server.get_cipher_version(), client.get_cipher_version() 2748 2749 assert isinstance(server_cipher_version, text_type) 2750 assert isinstance(client_cipher_version, text_type) 2751 2752 assert server_cipher_version == client_cipher_version 2753 2754 def test_get_cipher_bits_before_connect(self): 2755 """ 2756 `Connection.get_cipher_bits` returns `None` if no connection has 2757 been established. 2758 """ 2759 ctx = Context(TLSv1_METHOD) 2760 conn = Connection(ctx, None) 2761 assert conn.get_cipher_bits() is None 2762 2763 def test_get_cipher_bits(self): 2764 """ 2765 `Connection.get_cipher_bits` returns the number of secret bits 2766 of the currently used cipher. 2767 """ 2768 server, client = loopback() 2769 server_cipher_bits, client_cipher_bits = \ 2770 server.get_cipher_bits(), client.get_cipher_bits() 2771 2772 assert isinstance(server_cipher_bits, int) 2773 assert isinstance(client_cipher_bits, int) 2774 2775 assert server_cipher_bits == client_cipher_bits 2776 2777 def test_get_protocol_version_name(self): 2778 """ 2779 `Connection.get_protocol_version_name()` returns a text (`unicode`) string giving the 2780 protocol version of the current connection.
2781 """ 2782 server, client = loopback() 2783 client_protocol_version_name = client.get_protocol_version_name() 2784 server_protocol_version_name = server.get_protocol_version_name() 2785 2786 assert isinstance(server_protocol_version_name, text_type) 2787 assert isinstance(client_protocol_version_name, text_type) 2788 2789 assert server_protocol_version_name == client_protocol_version_name 2790 2791 def test_get_protocol_version(self): 2792 """ 2793 `Connection.get_protocol_version()` returns an integer 2794 giving the protocol version of the current connection. 2795 """ 2796 server, client = loopback() 2797 client_protocol_version = client.get_protocol_version() 2798 server_protocol_version = server.get_protocol_version() 2799 2800 assert isinstance(server_protocol_version, int) 2801 assert isinstance(client_protocol_version, int) 2802 2803 assert server_protocol_version == client_protocol_version 2804 2805 def test_wantReadError(self): 2806 """ 2807 `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are 2808 no bytes available to be read from the BIO. 2809 """ 2810 ctx = Context(TLSv1_METHOD) 2811 conn = Connection(ctx, None) 2812 with pytest.raises(WantReadError): 2813 conn.bio_read(1024) 2814 2815 @pytest.mark.parametrize('bufsize', [1.0, None, object(), 'bufsize']) 2816 def test_bio_read_wrong_args(self, bufsize): 2817 """ 2818 `Connection.bio_read` raises `TypeError` if passed a non-integer 2819 argument. 2820 """ 2821 ctx = Context(TLSv1_METHOD) 2822 conn = Connection(ctx, None) 2823 with pytest.raises(TypeError): 2824 conn.bio_read(bufsize) 2825 2826 def test_buffer_size(self): 2827 """ 2828 `Connection.bio_read` accepts an integer giving the maximum number 2829 of bytes to read and return.
2830 """ 2831 ctx = Context(TLSv1_METHOD) 2832 conn = Connection(ctx, None) 2833 conn.set_connect_state() 2834 try: 2835 conn.do_handshake() 2836 except WantReadError: 2837 pass 2838 data = conn.bio_read(2) 2839 assert 2 == len(data) 2840 2841 @skip_if_py3 2842 def test_buffer_size_long(self): 2843 """ 2844 On Python 2 `Connection.bio_read` accepts values of type `long` as 2845 well as `int`. 2846 """ 2847 ctx = Context(TLSv1_METHOD) 2848 conn = Connection(ctx, None) 2849 conn.set_connect_state() 2850 try: 2851 conn.do_handshake() 2852 except WantReadError: 2853 pass 2854 data = conn.bio_read(long(2)) 2855 assert 2 == len(data) 2856 2857 2858class TestConnectionGetCipherList(object): 2859 """ 2860 Tests for `Connection.get_cipher_list`. 2861 """ 2862 def test_result(self): 2863 """ 2864 `Connection.get_cipher_list` returns a list of native strings (`str`) giving the 2865 names of the ciphers which might be used. 2866 """ 2867 connection = Connection(Context(TLSv1_METHOD), None) 2868 ciphers = connection.get_cipher_list() 2869 assert isinstance(ciphers, list) 2870 for cipher in ciphers: 2871 assert isinstance(cipher, str) 2872 2873 2874class VeryLarge(bytes): 2875 """ 2876 `bytes` subclass whose `len()` reports 2**31, so size-limit checks can be exercised without allocating 2**31 bytes 2877 """ 2878 def __len__(self): 2879 return 2**31 2880 2881 2882class TestConnectionSend(object): 2883 """ 2884 Tests for `Connection.send`. 2885 """ 2886 def test_wrong_args(self): 2887 """ 2888 When called with a non-string argument for its first 2889 parameter, `Connection.send` raises `TypeError`. 2890 """ 2891 connection = Connection(Context(TLSv1_METHOD), None) 2892 with pytest.raises(TypeError): 2893 connection.send(object()) 2894 2895 def test_short_bytes(self): 2896 """ 2897 When passed a short byte string, `Connection.send` transmits all of it 2898 and returns the number of bytes sent.
2899 """ 2900 server, client = loopback() 2901 count = server.send(b'xy') 2902 assert count == 2 2903 assert client.recv(2) == b'xy' 2904 2905 def test_text(self): 2906 """ 2907 When passed text, `Connection.send` transmits all of it and 2908 returns the number of bytes sent. It also emits a DeprecationWarning. 2909 """ 2910 server, client = loopback() 2911 with pytest.warns(DeprecationWarning) as w: 2912 simplefilter("always") 2913 count = server.send(b"xy".decode("ascii")) 2914 assert ( 2915 "{0} for buf is no longer accepted, use bytes".format( 2916 WARNING_TYPE_EXPECTED 2917 ) == str(w[-1].message)) 2918 assert count == 2 2919 assert client.recv(2) == b'xy' 2920 2921 def test_short_memoryview(self): 2922 """ 2923 When passed a memoryview onto a small number of bytes, 2924 `Connection.send` transmits all of them and returns the number 2925 of bytes sent. 2926 """ 2927 server, client = loopback() 2928 count = server.send(memoryview(b'xy')) 2929 assert count == 2 2930 assert client.recv(2) == b'xy' 2931 2932 @skip_if_py3 2933 def test_short_buffer(self): 2934 """ 2935 When passed a buffer containing a small number of bytes, 2936 `Connection.send` transmits all of them and returns the number 2937 of bytes sent. 2938 """ 2939 server, client = loopback() 2940 count = server.send(buffer(b'xy')) 2941 assert count == 2 2942 assert client.recv(2) == b'xy' 2943 2944 @pytest.mark.skipif( 2945 sys.maxsize < 2**31, 2946 reason="sys.maxsize < 2**31 - test requires 64 bit" 2947 ) 2948 def test_buf_too_large(self): 2949 """ 2950 When passed a buffer containing >= 2**31 bytes, 2951 `Connection.send` bails out as SSL_write only 2952 accepts an int for the buffer length.
2953 """ 2954 connection = Connection(Context(TLSv1_METHOD), None) 2955 with pytest.raises(ValueError) as exc_info: 2956 connection.send(VeryLarge()) 2957 exc_info.match(r"Cannot send more than .+ bytes at once") 2958 2959 2960def _make_memoryview(size): 2961 """ 2962 Create a new ``memoryview`` wrapped around a zero-filled ``bytearray`` of the given 2963 size. 2964 """ 2965 return memoryview(bytearray(size)) 2966 2967 2968class TestConnectionRecvInto(object): 2969 """ 2970 Tests for `Connection.recv_into`. 2971 """ 2972 def _no_length_test(self, factory): 2973 """ 2974 Assert that when the given buffer is passed to `Connection.recv_into`, 2975 whatever bytes are available to be received that fit into that buffer 2976 are written into that buffer. 2977 """ 2978 output_buffer = factory(5) 2979 2980 server, client = loopback() 2981 server.send(b'xy') 2982 2983 assert client.recv_into(output_buffer) == 2 2984 assert output_buffer == bytearray(b'xy\x00\x00\x00') 2985 2986 def test_bytearray_no_length(self): 2987 """ 2988 `Connection.recv_into` can be passed a `bytearray` instance and data 2989 in the receive buffer is written to it. 2990 """ 2991 self._no_length_test(bytearray) 2992 2993 def _respects_length_test(self, factory): 2994 """ 2995 Assert that when the given buffer is passed to `Connection.recv_into` 2996 along with a value for `nbytes` that is less than the size of that 2997 buffer, only `nbytes` bytes are written into the buffer. 2998 """ 2999 output_buffer = factory(10) 3000 3001 server, client = loopback() 3002 server.send(b'abcdefghij') 3003 3004 assert client.recv_into(output_buffer, 5) == 5 3005 assert output_buffer == bytearray(b'abcde\x00\x00\x00\x00\x00') 3006 3007 def test_bytearray_respects_length(self): 3008 """ 3009 When called with a `bytearray` instance, `Connection.recv_into` 3010 respects the `nbytes` parameter and doesn't copy in more than that 3011 number of bytes.
3012 """ 3013 self._respects_length_test(bytearray) 3014 3015 def _doesnt_overfill_test(self, factory): 3016 """ 3017 Assert that if there are more bytes available to be read from the 3018 receive buffer than would fit into the buffer passed to 3019 `Connection.recv_into`, only as many as fit are written into it. 3020 """ 3021 output_buffer = factory(5) 3022 3023 server, client = loopback() 3024 server.send(b'abcdefghij') 3025 3026 assert client.recv_into(output_buffer) == 5 3027 assert output_buffer == bytearray(b'abcde') 3028 rest = client.recv(5) 3029 assert b'fghij' == rest 3030 3031 def test_bytearray_doesnt_overfill(self): 3032 """ 3033 When called with a `bytearray` instance, `Connection.recv_into` 3034 respects the size of the array and doesn't write more bytes into it 3035 than will fit. 3036 """ 3037 self._doesnt_overfill_test(bytearray) 3038 3039 def test_bytearray_really_doesnt_overfill(self): 3040 """ 3041 When called with a `bytearray` instance and an `nbytes` value that is 3042 too large, `Connection.recv_into` respects the size of the array and 3043 not the `nbytes` value and doesn't write more bytes into the buffer 3044 than will fit. 3045 """ 3046 self._doesnt_overfill_test(bytearray) 3047 3048 def test_peek(self): 3049 server, client = loopback() 3050 server.send(b'xy') 3051 3052 for _ in range(2): 3053 output_buffer = bytearray(5) 3054 assert client.recv_into(output_buffer, flags=MSG_PEEK) == 2 3055 assert output_buffer == bytearray(b'xy\x00\x00\x00') 3056 3057 def test_memoryview_no_length(self): 3058 """ 3059 `Connection.recv_into` can be passed a `memoryview` instance and data 3060 in the receive buffer is written to it. 3061 """ 3062 self._no_length_test(_make_memoryview) 3063 3064 def test_memoryview_respects_length(self): 3065 """ 3066 When called with a `memoryview` instance, `Connection.recv_into` 3067 respects the ``nbytes`` parameter and doesn't copy more than that 3068 number of bytes in. 
3069 """ 3070 self._respects_length_test(_make_memoryview) 3071 3072 def test_memoryview_doesnt_overfill(self): 3073 """ 3074 When called with a `memoryview` instance, `Connection.recv_into` 3075 respects the size of the array and doesn't write more bytes into it 3076 than will fit. 3077 """ 3078 self._doesnt_overfill_test(_make_memoryview) 3079 3080 def test_memoryview_really_doesnt_overfill(self): 3081 """ 3082 When called with a `memoryview` instance and an `nbytes` value that is 3083 too large, `Connection.recv_into` respects the size of the array and 3084 not the `nbytes` value and doesn't write more bytes into the buffer 3085 than will fit. 3086 """ 3087 self._doesnt_overfill_test(_make_memoryview) 3088 3089 3090class TestConnectionSendall(object): 3091 """ 3092 Tests for `Connection.sendall`. 3093 """ 3094 def test_wrong_args(self): 3095 """ 3096 When called with arguments other than a string argument for its first 3097 parameter, `Connection.sendall` raises `TypeError`. 3098 """ 3099 connection = Connection(Context(TLSv1_METHOD), None) 3100 with pytest.raises(TypeError): 3101 connection.sendall(object()) 3102 3103 def test_short(self): 3104 """ 3105 `Connection.sendall` transmits all of the bytes in the string 3106 passed to it. 3107 """ 3108 server, client = loopback() 3109 server.sendall(b'x') 3110 assert client.recv(1) == b'x' 3111 3112 def test_text(self): 3113 """ 3114 `Connection.sendall` transmits all the content in the string passed 3115 to it, raising a DeprecationWarning in case of this being a text. 
3116 """ 3117 server, client = loopback() 3118 with pytest.warns(DeprecationWarning) as w: 3119 simplefilter("always") 3120 server.sendall(b"x".decode("ascii")) 3121 assert ( 3122 "{0} for buf is no longer accepted, use bytes".format( 3123 WARNING_TYPE_EXPECTED 3124 ) == str(w[-1].message)) 3125 assert client.recv(1) == b"x" 3126 3127 def test_short_memoryview(self): 3128 """ 3129 When passed a memoryview onto a small number of bytes, 3130 `Connection.sendall` transmits all of them. 3131 """ 3132 server, client = loopback() 3133 server.sendall(memoryview(b'x')) 3134 assert client.recv(1) == b'x' 3135 3136 @skip_if_py3 3137 def test_short_buffers(self): 3138 """ 3139 When passed a buffer containing a small number of bytes, 3140 `Connection.sendall` transmits all of them. 3141 """ 3142 server, client = loopback() 3143 server.sendall(buffer(b'x')) 3144 assert client.recv(1) == b'x' 3145 3146 def test_long(self): 3147 """ 3148 `Connection.sendall` transmits all the bytes in the string passed to it 3149 even if this requires multiple calls of an underlying write function. 3150 """ 3151 server, client = loopback() 3152 # Should be enough, underlying SSL_write should only do 16k at a time. 3153 # On Windows, after 32k of bytes the write will block (forever 3154 # - because no one is yet reading). 3155 message = b'x' * (1024 * 32 - 1) + b'y' 3156 server.sendall(message) 3157 accum = [] 3158 received = 0 3159 while received < len(message): 3160 data = client.recv(1024) 3161 accum.append(data) 3162 received += len(data) 3163 assert message == b''.join(accum) 3164 3165 def test_closed(self): 3166 """ 3167 If the underlying socket is closed, `Connection.sendall` propagates the 3168 write error from the low level write call. 
3169 """ 3170 server, client = loopback() 3171 server.sock_shutdown(2) 3172 with pytest.raises(SysCallError) as err: 3173 server.sendall(b"hello, world") 3174 if platform == "win32": 3175 assert err.value.args[0] == ESHUTDOWN 3176 else: 3177 assert err.value.args[0] == EPIPE 3178 3179 3180class TestConnectionRenegotiate(object): 3181 """ 3182 Tests for SSL renegotiation APIs. 3183 """ 3184 def test_total_renegotiations(self): 3185 """ 3186 `Connection.total_renegotiations` returns `0` before any renegotiations 3187 have happened. 3188 """ 3189 connection = Connection(Context(TLSv1_METHOD), None) 3190 assert connection.total_renegotiations() == 0 3191 3192 def test_renegotiate(self): 3193 """ 3194 Go through a complete renegotiation cycle. 3195 """ 3196 server, client = loopback( 3197 lambda s: loopback_server_factory(s, TLSv1_2_METHOD), 3198 lambda s: loopback_client_factory(s, TLSv1_2_METHOD), 3199 ) 3200 3201 server.send(b"hello world") 3202 3203 assert b"hello world" == client.recv(len(b"hello world")) 3204 3205 assert 0 == server.total_renegotiations() 3206 assert False is server.renegotiate_pending() 3207 3208 assert True is server.renegotiate() 3209 3210 assert True is server.renegotiate_pending() 3211 3212 server.setblocking(False) 3213 client.setblocking(False) 3214 3215 client.do_handshake() 3216 server.do_handshake() 3217 3218 assert 1 == server.total_renegotiations() 3219 while False is server.renegotiate_pending(): 3220 pass 3221 3222 3223class TestError(object): 3224 """ 3225 Unit tests for `OpenSSL.SSL.Error`. 3226 """ 3227 def test_type(self): 3228 """ 3229 `Error` is an exception type. 3230 """ 3231 assert issubclass(Error, Exception) 3232 assert Error.__name__ == 'Error' 3233 3234 3235class TestConstants(object): 3236 """ 3237 Tests for the values of constants exposed in `OpenSSL.SSL`. 3238 3239 These are values defined by OpenSSL intended only to be used as flags to 3240 OpenSSL APIs. 
The only assertions it seems can be made about them is 3241 their values. 3242 """ 3243 @pytest.mark.skipif( 3244 OP_NO_QUERY_MTU is None, 3245 reason="OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old" 3246 ) 3247 def test_op_no_query_mtu(self): 3248 """ 3249 The value of `OpenSSL.SSL.OP_NO_QUERY_MTU` is 0x1000, the value 3250 of `SSL_OP_NO_QUERY_MTU` defined by `openssl/ssl.h`. 3251 """ 3252 assert OP_NO_QUERY_MTU == 0x1000 3253 3254 @pytest.mark.skipif( 3255 OP_COOKIE_EXCHANGE is None, 3256 reason="OP_COOKIE_EXCHANGE unavailable - " 3257 "OpenSSL version may be too old" 3258 ) 3259 def test_op_cookie_exchange(self): 3260 """ 3261 The value of `OpenSSL.SSL.OP_COOKIE_EXCHANGE` is 0x2000, the 3262 value of `SSL_OP_COOKIE_EXCHANGE` defined by `openssl/ssl.h`. 3263 """ 3264 assert OP_COOKIE_EXCHANGE == 0x2000 3265 3266 @pytest.mark.skipif( 3267 OP_NO_TICKET is None, 3268 reason="OP_NO_TICKET unavailable - OpenSSL version may be too old" 3269 ) 3270 def test_op_no_ticket(self): 3271 """ 3272 The value of `OpenSSL.SSL.OP_NO_TICKET` is 0x4000, the value of 3273 `SSL_OP_NO_TICKET` defined by `openssl/ssl.h`. 3274 """ 3275 assert OP_NO_TICKET == 0x4000 3276 3277 @pytest.mark.skipif( 3278 OP_NO_COMPRESSION is None, 3279 reason="OP_NO_COMPRESSION unavailable - OpenSSL version may be too old" 3280 ) 3281 def test_op_no_compression(self): 3282 """ 3283 The value of `OpenSSL.SSL.OP_NO_COMPRESSION` is 0x20000, the 3284 value of `SSL_OP_NO_COMPRESSION` defined by `openssl/ssl.h`. 3285 """ 3286 assert OP_NO_COMPRESSION == 0x20000 3287 3288 def test_sess_cache_off(self): 3289 """ 3290 The value of `OpenSSL.SSL.SESS_CACHE_OFF` 0x0, the value of 3291 `SSL_SESS_CACHE_OFF` defined by `openssl/ssl.h`. 3292 """ 3293 assert 0x0 == SESS_CACHE_OFF 3294 3295 def test_sess_cache_client(self): 3296 """ 3297 The value of `OpenSSL.SSL.SESS_CACHE_CLIENT` 0x1, the value of 3298 `SSL_SESS_CACHE_CLIENT` defined by `openssl/ssl.h`. 
3299 """ 3300 assert 0x1 == SESS_CACHE_CLIENT 3301 3302 def test_sess_cache_server(self): 3303 """ 3304 The value of `OpenSSL.SSL.SESS_CACHE_SERVER` 0x2, the value of 3305 `SSL_SESS_CACHE_SERVER` defined by `openssl/ssl.h`. 3306 """ 3307 assert 0x2 == SESS_CACHE_SERVER 3308 3309 def test_sess_cache_both(self): 3310 """ 3311 The value of `OpenSSL.SSL.SESS_CACHE_BOTH` 0x3, the value of 3312 `SSL_SESS_CACHE_BOTH` defined by `openssl/ssl.h`. 3313 """ 3314 assert 0x3 == SESS_CACHE_BOTH 3315 3316 def test_sess_cache_no_auto_clear(self): 3317 """ 3318 The value of `OpenSSL.SSL.SESS_CACHE_NO_AUTO_CLEAR` 0x80, the 3319 value of `SSL_SESS_CACHE_NO_AUTO_CLEAR` defined by 3320 `openssl/ssl.h`. 3321 """ 3322 assert 0x80 == SESS_CACHE_NO_AUTO_CLEAR 3323 3324 def test_sess_cache_no_internal_lookup(self): 3325 """ 3326 The value of `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_LOOKUP` 0x100, 3327 the value of `SSL_SESS_CACHE_NO_INTERNAL_LOOKUP` defined by 3328 `openssl/ssl.h`. 3329 """ 3330 assert 0x100 == SESS_CACHE_NO_INTERNAL_LOOKUP 3331 3332 def test_sess_cache_no_internal_store(self): 3333 """ 3334 The value of `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_STORE` 0x200, 3335 the value of `SSL_SESS_CACHE_NO_INTERNAL_STORE` defined by 3336 `openssl/ssl.h`. 3337 """ 3338 assert 0x200 == SESS_CACHE_NO_INTERNAL_STORE 3339 3340 def test_sess_cache_no_internal(self): 3341 """ 3342 The value of `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL` 0x300, the 3343 value of `SSL_SESS_CACHE_NO_INTERNAL` defined by 3344 `openssl/ssl.h`. 3345 """ 3346 assert 0x300 == SESS_CACHE_NO_INTERNAL 3347 3348 3349class TestMemoryBIO(object): 3350 """ 3351 Tests for `OpenSSL.SSL.Connection` using a memory BIO. 3352 """ 3353 def _server(self, sock): 3354 """ 3355 Create a new server-side SSL `Connection` object wrapped around `sock`. 3356 """ 3357 # Create the server side Connection. This is mostly setup boilerplate 3358 # - use TLSv1, use a particular certificate, etc. 
3359 server_ctx = Context(TLSv1_METHOD) 3360 server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) 3361 server_ctx.set_verify( 3362 VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, 3363 verify_cb 3364 ) 3365 server_store = server_ctx.get_cert_store() 3366 server_ctx.use_privatekey( 3367 load_privatekey(FILETYPE_PEM, server_key_pem)) 3368 server_ctx.use_certificate( 3369 load_certificate(FILETYPE_PEM, server_cert_pem)) 3370 server_ctx.check_privatekey() 3371 server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) 3372 # Here the Connection is actually created. If None is passed as the 3373 # 2nd parameter, it indicates a memory BIO should be created. 3374 server_conn = Connection(server_ctx, sock) 3375 server_conn.set_accept_state() 3376 return server_conn 3377 3378 def _client(self, sock): 3379 """ 3380 Create a new client-side SSL `Connection` object wrapped around `sock`. 3381 """ 3382 # Now create the client side Connection. Similar boilerplate to the 3383 # above. 
3384 client_ctx = Context(TLSv1_METHOD) 3385 client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) 3386 client_ctx.set_verify( 3387 VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE, 3388 verify_cb 3389 ) 3390 client_store = client_ctx.get_cert_store() 3391 client_ctx.use_privatekey( 3392 load_privatekey(FILETYPE_PEM, client_key_pem)) 3393 client_ctx.use_certificate( 3394 load_certificate(FILETYPE_PEM, client_cert_pem)) 3395 client_ctx.check_privatekey() 3396 client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) 3397 client_conn = Connection(client_ctx, sock) 3398 client_conn.set_connect_state() 3399 return client_conn 3400 3401 def test_memory_connect(self): 3402 """ 3403 Two `Connection`s which use memory BIOs can be manually connected by 3404 reading from the output of each and writing those bytes to the input of 3405 the other and in this way establish a connection and exchange 3406 application-level bytes with each other. 3407 """ 3408 server_conn = self._server(None) 3409 client_conn = self._client(None) 3410 3411 # There should be no key or nonces yet. 3412 assert server_conn.master_key() is None 3413 assert server_conn.client_random() is None 3414 assert server_conn.server_random() is None 3415 3416 # First, the handshake needs to happen. We'll deliver bytes back and 3417 # forth between the client and server until neither of them feels like 3418 # speaking any more. 3419 assert interact_in_memory(client_conn, server_conn) is None 3420 3421 # Now that the handshake is done, there should be a key and nonces. 
3422 assert server_conn.master_key() is not None 3423 assert server_conn.client_random() is not None 3424 assert server_conn.server_random() is not None 3425 assert server_conn.client_random() == client_conn.client_random() 3426 assert server_conn.server_random() == client_conn.server_random() 3427 assert server_conn.client_random() != server_conn.server_random() 3428 assert client_conn.client_random() != client_conn.server_random() 3429 3430 # Export key material for other uses. 3431 cekm = client_conn.export_keying_material(b'LABEL', 32) 3432 sekm = server_conn.export_keying_material(b'LABEL', 32) 3433 assert cekm is not None 3434 assert sekm is not None 3435 assert cekm == sekm 3436 assert len(sekm) == 32 3437 3438 # Export key material for other uses with additional context. 3439 cekmc = client_conn.export_keying_material(b'LABEL', 32, b'CONTEXT') 3440 sekmc = server_conn.export_keying_material(b'LABEL', 32, b'CONTEXT') 3441 assert cekmc is not None 3442 assert sekmc is not None 3443 assert cekmc == sekmc 3444 assert cekmc != cekm 3445 assert sekmc != sekm 3446 # Export with alternate label 3447 cekmt = client_conn.export_keying_material(b'test', 32, b'CONTEXT') 3448 sekmt = server_conn.export_keying_material(b'test', 32, b'CONTEXT') 3449 assert cekmc != cekmt 3450 assert sekmc != sekmt 3451 3452 # Here are the bytes we'll try to send. 3453 important_message = b'One if by land, two if by sea.' 3454 3455 server_conn.write(important_message) 3456 assert ( 3457 interact_in_memory(client_conn, server_conn) == 3458 (client_conn, important_message)) 3459 3460 client_conn.write(important_message[::-1]) 3461 assert ( 3462 interact_in_memory(client_conn, server_conn) == 3463 (server_conn, important_message[::-1])) 3464 3465 def test_socket_connect(self): 3466 """ 3467 Just like `test_memory_connect` but with an actual socket. 
3468 3469 This is primarily to rule out the memory BIO code as the source of any 3470 problems encountered while passing data over a `Connection` (if 3471 this test fails, there must be a problem outside the memory BIO code, 3472 as no memory BIO is involved here). Even though this isn't a memory 3473 BIO test, it's convenient to have it here. 3474 """ 3475 server_conn, client_conn = loopback() 3476 3477 important_message = b"Help me Obi Wan Kenobi, you're my only hope." 3478 client_conn.send(important_message) 3479 msg = server_conn.recv(1024) 3480 assert msg == important_message 3481 3482 # Again in the other direction, just for fun. 3483 important_message = important_message[::-1] 3484 server_conn.send(important_message) 3485 msg = client_conn.recv(1024) 3486 assert msg == important_message 3487 3488 def test_socket_overrides_memory(self): 3489 """ 3490 Test that `OpenSSL.SSL.bio_read` and `OpenSSL.SSL.bio_write` don't 3491 work on `OpenSSL.SSL.Connection`() that use sockets. 3492 """ 3493 context = Context(TLSv1_METHOD) 3494 client = socket() 3495 clientSSL = Connection(context, client) 3496 with pytest.raises(TypeError): 3497 clientSSL.bio_read(100) 3498 with pytest.raises(TypeError): 3499 clientSSL.bio_write("foo") 3500 with pytest.raises(TypeError): 3501 clientSSL.bio_shutdown() 3502 3503 def test_outgoing_overflow(self): 3504 """ 3505 If more bytes than can be written to the memory BIO are passed to 3506 `Connection.send` at once, the number of bytes which were written is 3507 returned and that many bytes from the beginning of the input can be 3508 read from the other end of the connection. 3509 """ 3510 server = self._server(None) 3511 client = self._client(None) 3512 3513 interact_in_memory(client, server) 3514 3515 size = 2 ** 15 3516 sent = client.send(b"x" * size) 3517 # Sanity check. We're trying to test what happens when the entire 3518 # input can't be sent. If the entire input was sent, this test is 3519 # meaningless. 
3520 assert sent < size 3521 3522 receiver, received = interact_in_memory(client, server) 3523 assert receiver is server 3524 3525 # We can rely on all of these bytes being received at once because 3526 # loopback passes 2 ** 16 to recv - more than 2 ** 15. 3527 assert len(received) == sent 3528 3529 def test_shutdown(self): 3530 """ 3531 `Connection.bio_shutdown` signals the end of the data stream 3532 from which the `Connection` reads. 3533 """ 3534 server = self._server(None) 3535 server.bio_shutdown() 3536 with pytest.raises(Error) as err: 3537 server.recv(1024) 3538 # We don't want WantReadError or ZeroReturnError or anything - it's a 3539 # handshake failure. 3540 assert type(err.value) in [Error, SysCallError] 3541 3542 def test_unexpected_EOF(self): 3543 """ 3544 If the connection is lost before an orderly SSL shutdown occurs, 3545 `OpenSSL.SSL.SysCallError` is raised with a message of 3546 "Unexpected EOF". 3547 """ 3548 server_conn, client_conn = loopback() 3549 client_conn.sock_shutdown(SHUT_RDWR) 3550 with pytest.raises(SysCallError) as err: 3551 server_conn.recv(1024) 3552 assert err.value.args == (-1, "Unexpected EOF") 3553 3554 def _check_client_ca_list(self, func): 3555 """ 3556 Verify the return value of the `get_client_ca_list` method for 3557 server and client connections. 3558 3559 :param func: A function which will be called with the server context 3560 before the client and server are connected to each other. This 3561 function should specify a list of CAs for the server to send to the 3562 client and return that same list. The list will be used to verify 3563 that `get_client_ca_list` returns the proper value at 3564 various times. 
3565 """ 3566 server = self._server(None) 3567 client = self._client(None) 3568 assert client.get_client_ca_list() == [] 3569 assert server.get_client_ca_list() == [] 3570 ctx = server.get_context() 3571 expected = func(ctx) 3572 assert client.get_client_ca_list() == [] 3573 assert server.get_client_ca_list() == expected 3574 interact_in_memory(client, server) 3575 assert client.get_client_ca_list() == expected 3576 assert server.get_client_ca_list() == expected 3577 3578 def test_set_client_ca_list_errors(self): 3579 """ 3580 `Context.set_client_ca_list` raises a `TypeError` if called with a 3581 non-list or a list that contains objects other than X509Names. 3582 """ 3583 ctx = Context(TLSv1_METHOD) 3584 with pytest.raises(TypeError): 3585 ctx.set_client_ca_list("spam") 3586 with pytest.raises(TypeError): 3587 ctx.set_client_ca_list(["spam"]) 3588 3589 def test_set_empty_ca_list(self): 3590 """ 3591 If passed an empty list, `Context.set_client_ca_list` configures the 3592 context to send no CA names to the client and, on both the server and 3593 client sides, `Connection.get_client_ca_list` returns an empty list 3594 after the connection is set up. 3595 """ 3596 def no_ca(ctx): 3597 ctx.set_client_ca_list([]) 3598 return [] 3599 self._check_client_ca_list(no_ca) 3600 3601 def test_set_one_ca_list(self): 3602 """ 3603 If passed a list containing a single X509Name, 3604 `Context.set_client_ca_list` configures the context to send 3605 that CA name to the client and, on both the server and client sides, 3606 `Connection.get_client_ca_list` returns a list containing that 3607 X509Name after the connection is set up. 
3608 """ 3609 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3610 cadesc = cacert.get_subject() 3611 3612 def single_ca(ctx): 3613 ctx.set_client_ca_list([cadesc]) 3614 return [cadesc] 3615 self._check_client_ca_list(single_ca) 3616 3617 def test_set_multiple_ca_list(self): 3618 """ 3619 If passed a list containing multiple X509Name objects, 3620 `Context.set_client_ca_list` configures the context to send 3621 those CA names to the client and, on both the server and client sides, 3622 `Connection.get_client_ca_list` returns a list containing those 3623 X509Names after the connection is set up. 3624 """ 3625 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3626 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3627 3628 sedesc = secert.get_subject() 3629 cldesc = clcert.get_subject() 3630 3631 def multiple_ca(ctx): 3632 L = [sedesc, cldesc] 3633 ctx.set_client_ca_list(L) 3634 return L 3635 self._check_client_ca_list(multiple_ca) 3636 3637 def test_reset_ca_list(self): 3638 """ 3639 If called multiple times, only the X509Names passed to the final call 3640 of `Context.set_client_ca_list` are used to configure the CA 3641 names sent to the client. 3642 """ 3643 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3644 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3645 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3646 3647 cadesc = cacert.get_subject() 3648 sedesc = secert.get_subject() 3649 cldesc = clcert.get_subject() 3650 3651 def changed_ca(ctx): 3652 ctx.set_client_ca_list([sedesc, cldesc]) 3653 ctx.set_client_ca_list([cadesc]) 3654 return [cadesc] 3655 self._check_client_ca_list(changed_ca) 3656 3657 def test_mutated_ca_list(self): 3658 """ 3659 If the list passed to `Context.set_client_ca_list` is mutated 3660 afterwards, this does not affect the list of CA names sent to the 3661 client. 
3662 """ 3663 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3664 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3665 3666 cadesc = cacert.get_subject() 3667 sedesc = secert.get_subject() 3668 3669 def mutated_ca(ctx): 3670 L = [cadesc] 3671 ctx.set_client_ca_list([cadesc]) 3672 L.append(sedesc) 3673 return [cadesc] 3674 self._check_client_ca_list(mutated_ca) 3675 3676 def test_add_client_ca_wrong_args(self): 3677 """ 3678 `Context.add_client_ca` raises `TypeError` if called with 3679 a non-X509 object. 3680 """ 3681 ctx = Context(TLSv1_METHOD) 3682 with pytest.raises(TypeError): 3683 ctx.add_client_ca("spam") 3684 3685 def test_one_add_client_ca(self): 3686 """ 3687 A certificate's subject can be added as a CA to be sent to the client 3688 with `Context.add_client_ca`. 3689 """ 3690 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3691 cadesc = cacert.get_subject() 3692 3693 def single_ca(ctx): 3694 ctx.add_client_ca(cacert) 3695 return [cadesc] 3696 self._check_client_ca_list(single_ca) 3697 3698 def test_multiple_add_client_ca(self): 3699 """ 3700 Multiple CA names can be sent to the client by calling 3701 `Context.add_client_ca` with multiple X509 objects. 3702 """ 3703 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3704 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3705 3706 cadesc = cacert.get_subject() 3707 sedesc = secert.get_subject() 3708 3709 def multiple_ca(ctx): 3710 ctx.add_client_ca(cacert) 3711 ctx.add_client_ca(secert) 3712 return [cadesc, sedesc] 3713 self._check_client_ca_list(multiple_ca) 3714 3715 def test_set_and_add_client_ca(self): 3716 """ 3717 A call to `Context.set_client_ca_list` followed by a call to 3718 `Context.add_client_ca` results in using the CA names from the 3719 first call and the CA name from the second call. 
3720 """ 3721 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3722 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3723 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3724 3725 cadesc = cacert.get_subject() 3726 sedesc = secert.get_subject() 3727 cldesc = clcert.get_subject() 3728 3729 def mixed_set_add_ca(ctx): 3730 ctx.set_client_ca_list([cadesc, sedesc]) 3731 ctx.add_client_ca(clcert) 3732 return [cadesc, sedesc, cldesc] 3733 self._check_client_ca_list(mixed_set_add_ca) 3734 3735 def test_set_after_add_client_ca(self): 3736 """ 3737 A call to `Context.set_client_ca_list` after a call to 3738 `Context.add_client_ca` replaces the CA name specified by the 3739 former call with the names specified by the latter call. 3740 """ 3741 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3742 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3743 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3744 3745 cadesc = cacert.get_subject() 3746 sedesc = secert.get_subject() 3747 3748 def set_replaces_add_ca(ctx): 3749 ctx.add_client_ca(clcert) 3750 ctx.set_client_ca_list([cadesc]) 3751 ctx.add_client_ca(secert) 3752 return [cadesc, sedesc] 3753 self._check_client_ca_list(set_replaces_add_ca) 3754 3755 3756class TestInfoConstants(object): 3757 """ 3758 Tests for assorted constants exposed for use in info callbacks. 3759 """ 3760 def test_integers(self): 3761 """ 3762 All of the info constants are integers. 3763 3764 This is a very weak test. It would be nice to have one that actually 3765 verifies that as certain info events happen, the value passed to the 3766 info callback matches up with the constant exposed by OpenSSL.SSL. 
3767 """ 3768 for const in [ 3769 SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, 3770 SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT, 3771 SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP, 3772 SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT, 3773 SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE 3774 ]: 3775 assert isinstance(const, int) 3776 3777 # These constants don't exist on OpenSSL 1.1.0 3778 for const in [ 3779 SSL_ST_INIT, SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE 3780 ]: 3781 assert const is None or isinstance(const, int) 3782 3783 3784class TestRequires(object): 3785 """ 3786 Tests for the decorator factory used to conditionally raise 3787 NotImplementedError when older OpenSSLs are used. 3788 """ 3789 def test_available(self): 3790 """ 3791 When the OpenSSL functionality is available the decorated functions 3792 work appropriately. 3793 """ 3794 feature_guard = _make_requires(True, "Error text") 3795 results = [] 3796 3797 @feature_guard 3798 def inner(): 3799 results.append(True) 3800 return True 3801 3802 assert inner() is True 3803 assert [True] == results 3804 3805 def test_unavailable(self): 3806 """ 3807 When the OpenSSL functionality is not available the decorated function 3808 does not execute and NotImplementedError is raised. 3809 """ 3810 feature_guard = _make_requires(False, "Error text") 3811 3812 @feature_guard 3813 def inner(): # pragma: nocover 3814 pytest.fail("Should not be called") 3815 3816 with pytest.raises(NotImplementedError) as e: 3817 inner() 3818 3819 assert "Error text" in str(e.value) 3820 3821 3822class TestOCSP(object): 3823 """ 3824 Tests for PyOpenSSL's OCSP stapling support. 3825 """ 3826 sample_ocsp_data = b"this is totally ocsp data" 3827 3828 def _client_connection(self, callback, data, request_ocsp=True): 3829 """ 3830 Builds a client connection suitable for using OCSP. 3831 3832 :param callback: The callback to register for OCSP. 
:param data: The opaque data object that will be handed to the
            OCSP callback.
        :param request_ocsp: Whether the client will actually ask for OCSP
            stapling. Useful for testing only.
        :return: A client-side :class:`Connection`, put into connect state,
            whose context has ``callback`` registered as its OCSP client
            callback.
        """
        ctx = Context(SSLv23_METHOD)
        # ``data`` is handed back to ``callback`` as its final argument;
        # test_opaque_data_is_passed_through checks this.
        ctx.set_ocsp_client_callback(callback, data)
        client = Connection(ctx)

        # Only ask the server for a stapled OCSP response when requested;
        # test_callbacks_arent_called_by_default relies on turning this off.
        if request_ocsp:
            client.request_ocsp()

        client.set_connect_state()
        return client

    def _server_connection(self, callback, data):
        """
        Builds a server connection suitable for using OCSP.

        :param callback: The callback to register for OCSP.
        :param data: The opaque data object that will be handed to the
            OCSP callback.
        :return: A server-side :class:`Connection`, put into accept state,
            using the ``server_key_pem`` / ``server_cert_pem`` key pair and
            with ``callback`` registered as its OCSP server callback.
        """
        ctx = Context(SSLv23_METHOD)
        # Give the server a key and certificate so it can complete the
        # handshake with the client built by _client_connection.
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
        ctx.set_ocsp_server_callback(callback, data)
        server = Connection(ctx)
        server.set_accept_state()
        return server

    def test_callbacks_arent_called_by_default(self):
        """
        If both the client and the server have registered OCSP callbacks, but
        the client does not send the OCSP request, neither callback gets
        called.
        """
        # The same always-failing callback is registered on both sides, so
        # the test passes only if the handshake completes without either
        # side invoking it.
        def ocsp_callback(*args, **kwargs):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(
            callback=ocsp_callback, data=None, request_ocsp=False
        )
        server = self._server_connection(callback=ocsp_callback, data=None)
        handshake_in_memory(client, server)

    def test_client_negotiates_without_server(self):
        """
        If the client wants to do OCSP but the server does not, the handshake
        succeeds, and the client callback fires with an empty byte string.
        """
        called = []

        # Returning True accepts the (empty) response and lets the handshake
        # continue; returning False is covered by
        # test_client_returns_false_terminates_handshake.
        def ocsp_callback(conn, ocsp_data, ignored):
            called.append(ocsp_data)
            return True

        client = self._client_connection(callback=ocsp_callback, data=None)
        # NOTE(review): loopback_server_factory is defined elsewhere in this
        # module; presumably it builds a server with no OCSP callback
        # registered, matching this test's premise -- confirm.
        server = loopback_server_factory(socket=None)
        handshake_in_memory(client, server)

        assert len(called) == 1
        assert called[0] == b''

    def test_client_receives_servers_data(self):
        """
        The data the server sends in its callback is received by the client.
        """
        calls = []

        # ``self.sample_ocsp_data`` is a class attribute defined outside this
        # view (assumed to be bytes; test_server_must_return_bytes decodes it).
        def server_callback(*args, **kwargs):
            return self.sample_ocsp_data

        def client_callback(conn, ocsp_data, ignored):
            calls.append(ocsp_data)
            return True

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert len(calls) == 1
        assert calls[0] == self.sample_ocsp_data

    def test_callbacks_are_invoked_with_connections(self):
        """
        The first arguments to both callbacks are their respective connections.
        """
        client_calls = []
        server_calls = []

        def client_callback(conn, *args, **kwargs):
            client_calls.append(conn)
            return True

        def server_callback(conn, *args, **kwargs):
            server_calls.append(conn)
            return self.sample_ocsp_data

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert len(client_calls) == 1
        assert len(server_calls) == 1
        # Identity, not equality: each callback must get the exact
        # Connection object it was registered on.
        assert client_calls[0] is client
        assert server_calls[0] is server

    def test_opaque_data_is_passed_through(self):
        """
        Both callbacks receive an opaque, user-provided piece of data in their
        callbacks as the final argument.
        """
        # Both callbacks record into the same list; the assertions below do
        # not depend on which side runs first.
        calls = []

        def server_callback(*args):
            calls.append(args)
            return self.sample_ocsp_data

        def client_callback(*args):
            calls.append(args)
            return True

        sentinel = object()

        client = self._client_connection(
            callback=client_callback, data=sentinel
        )
        server = self._server_connection(
            callback=server_callback, data=sentinel
        )
        handshake_in_memory(client, server)

        assert len(calls) == 2
        assert calls[0][-1] is sentinel
        assert calls[1][-1] is sentinel

    def test_server_returns_empty_string(self):
        """
        If the server returns an empty bytestring from its callback, the
        client callback is called with the empty bytestring.
        """
        client_calls = []

        def server_callback(*args):
            return b''

        def client_callback(conn, ocsp_data, ignored):
            client_calls.append(ocsp_data)
            return True

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)
        handshake_in_memory(client, server)

        assert len(client_calls) == 1
        assert client_calls[0] == b''

    def test_client_returns_false_terminates_handshake(self):
        """
        If the client returns False from its callback, the handshake fails.
        """
        def server_callback(*args):
            return self.sample_ocsp_data

        def client_callback(*args):
            return False

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(Error):
            handshake_in_memory(client, server)

    def test_exceptions_in_client_bubble_up(self):
        """
        Exceptions raised in the client callback propagate out of the
        handshake to the caller.
        """
        # A test-local exception type, so the raise below cannot be
        # confused with any exception pyOpenSSL itself might raise.
        class SentinelException(Exception):
            pass

        def server_callback(*args):
            return self.sample_ocsp_data

        def client_callback(*args):
            raise SentinelException()

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(SentinelException):
            handshake_in_memory(client, server)

    def test_exceptions_in_server_bubble_up(self):
        """
        Exceptions raised in the server callback propagate out of the
        handshake to the caller, and the client callback is never invoked.
        """
        # A test-local exception type, so the raise below cannot be
        # confused with any exception pyOpenSSL itself might raise.
        class SentinelException(Exception):
            pass

        def server_callback(*args):
            raise SentinelException()

        def client_callback(*args):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(SentinelException):
            handshake_in_memory(client, server)

    def test_server_must_return_bytes(self):
        """
        The server callback must return a bytestring; returning text instead
        raises TypeError.
        """
        # Decode the sample data so the callback returns text, not bytes.
        def server_callback(*args):
            return self.sample_ocsp_data.decode('ascii')

        def client_callback(*args):  # pragma: nocover
            pytest.fail("Should not be called")

        client = self._client_connection(callback=client_callback, data=None)
        server = self._server_connection(callback=server_callback, data=None)

        with pytest.raises(TypeError):
            handshake_in_memory(client, server)