# Copyright (C) Jean-Paul Calderone
# See LICENSE for details.

"""
Unit tests for :mod:`OpenSSL.SSL`.
"""

import datetime
import gc
import sys
import uuid

from gc import collect, get_referrers
from errno import (
    EAFNOSUPPORT,
    ECONNREFUSED,
    EINPROGRESS,
    EWOULDBLOCK,
    EPIPE,
    ESHUTDOWN,
)
from sys import platform, getfilesystemencoding
from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket
from os import makedirs
from os.path import join
from weakref import ref
from warnings import simplefilter

import flaky

import pytest

from pretend import raiser

from six import PY2, text_type

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID


from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
from OpenSSL.crypto import PKey, X509, X509Extension, X509Store
from OpenSSL.crypto import dump_privatekey, load_privatekey
from OpenSSL.crypto import dump_certificate, load_certificate
from OpenSSL.crypto import get_elliptic_curves

from OpenSSL.SSL import OPENSSL_VERSION_NUMBER, SSLEAY_VERSION, SSLEAY_CFLAGS
from OpenSSL.SSL import SSLEAY_PLATFORM, SSLEAY_DIR, SSLEAY_BUILT_ON
from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
from OpenSSL.SSL import (
    SSLv2_METHOD,
    SSLv3_METHOD,
    SSLv23_METHOD,
    TLSv1_METHOD,
    TLSv1_1_METHOD,
    TLSv1_2_METHOD,
)
from OpenSSL.SSL import OP_SINGLE_DH_USE, OP_NO_SSLv2, OP_NO_SSLv3
from OpenSSL.SSL import (
    VERIFY_PEER,
    VERIFY_FAIL_IF_NO_PEER_CERT,
    VERIFY_CLIENT_ONCE,
    VERIFY_NONE,
)

from OpenSSL import SSL
from OpenSSL.SSL import (
    SESS_CACHE_OFF,
    SESS_CACHE_CLIENT,
    SESS_CACHE_SERVER,
    SESS_CACHE_BOTH,
    SESS_CACHE_NO_AUTO_CLEAR,
    SESS_CACHE_NO_INTERNAL_LOOKUP,
    SESS_CACHE_NO_INTERNAL_STORE,
    SESS_CACHE_NO_INTERNAL,
)

from OpenSSL.SSL import (
    Error,
    SysCallError,
    WantReadError,
    WantWriteError,
    ZeroReturnError,
)
from OpenSSL.SSL import Context, Session, Connection, SSLeay_version
from OpenSSL.SSL import _make_requires

from OpenSSL._util import ffi as _ffi, lib as _lib

from OpenSSL.SSL import (
    OP_NO_QUERY_MTU,
    OP_COOKIE_EXCHANGE,
    OP_NO_TICKET,
    OP_NO_COMPRESSION,
    MODE_RELEASE_BUFFERS,
    NO_OVERLAPPING_PROTOCOLS,
)

from OpenSSL.SSL import (
    SSL_ST_CONNECT,
    SSL_ST_ACCEPT,
    SSL_ST_MASK,
    SSL_CB_LOOP,
    SSL_CB_EXIT,
    SSL_CB_READ,
    SSL_CB_WRITE,
    SSL_CB_ALERT,
    SSL_CB_READ_ALERT,
    SSL_CB_WRITE_ALERT,
    SSL_CB_ACCEPT_LOOP,
    SSL_CB_ACCEPT_EXIT,
    SSL_CB_CONNECT_LOOP,
    SSL_CB_CONNECT_EXIT,
    SSL_CB_HANDSHAKE_START,
    SSL_CB_HANDSHAKE_DONE,
)

# These state constants are not importable from every OpenSSL.SSL build; fall
# back to ``None`` so the names always exist at module level.
try:
    from OpenSSL.SSL import (
        SSL_ST_INIT,
        SSL_ST_BEFORE,
        SSL_ST_OK,
        SSL_ST_RENEGOTIATE,
    )
except ImportError:
    SSL_ST_INIT = SSL_ST_BEFORE = SSL_ST_OK = SSL_ST_RENEGOTIATE = None

from .util import WARNING_TYPE_EXPECTED, NON_ASCII, is_consistent_type
from .test_crypto import (
    client_cert_pem,
    client_key_pem,
    server_cert_pem,
    server_key_pem,
    root_cert_pem,
    root_key_pem,
)


# openssl dhparam 2048 -out dh-2048.pem
dhparam = """\
-----BEGIN DH PARAMETERS-----
MIIBCAKCAQEA2F5e976d/GjsaCdKv5RMWL/YV7fq1UUWpPAer5fDXflLMVUuYXxE
3m3ayZob9lbpgEU0jlPAsXHfQPGxpKmvhv+xV26V/DEoukED8JeZUY/z4pigoptl
+8+TYdNNE/rFSZQFXIp+v2D91IEgmHBnZlKFSbKR+p8i0KjExXGjU6ji3S5jkOku
ogikc7df1Ui0hWNJCmTjExq07aXghk97PsdFSxjdawuG3+vos5bnNoUwPLYlFc/z
ITYG0KXySiCLi4UDlXTZTz7u/+OYczPEgqa/JPUddbM/kfvaRAnjY38cfQ7qXf8Y
i5s5yYK7a/0eWxxRr2qraYaUj8RwDpH9CwIBAg==
-----END DH PARAMETERS-----
"""


# NOTE: despite the name, this marker skips a test when *not* running on
# Python 2, i.e. it marks tests that are Python 2 only.
skip_if_py3 = pytest.mark.skipif(not PY2, reason="Python 2 only")


def socket_any_family():
    """
    Return a new socket, preferring IPv4 and falling back to IPv6.

    An ``AF_INET6`` socket is created only when creating the ``AF_INET`` one
    fails with ``EAFNOSUPPORT``; any other socket error is re-raised.
    """
    try:
        return socket(AF_INET)
    except error as e:
        if e.errno == EAFNOSUPPORT:
            return socket(AF_INET6)
        raise


def loopback_address(socket):
    """
    Return the loopback address literal matching the family of ``socket``.

    :param socket: An object with a ``family`` attribute equal to either
        ``AF_INET`` or ``AF_INET6``.
    :return: ``"127.0.0.1"`` for ``AF_INET``, ``"::1"`` for ``AF_INET6``.
    """
    if socket.family == AF_INET:
        return "127.0.0.1"
    else:
        assert socket.family == AF_INET6
        return "::1"


def join_bytes_or_unicode(prefix, suffix):
    """
    Join two path components of either ``bytes`` or ``unicode``.

    The return type is the same as the type of ``prefix``.
    """
    # If the types are the same, nothing special is necessary.
    if type(prefix) is type(suffix):
        return join(prefix, suffix)

    # Otherwise, coerce suffix to the type of prefix.
    if isinstance(prefix, text_type):
        return join(prefix, suffix.decode(getfilesystemencoding()))
    else:
        return join(prefix, suffix.encode(getfilesystemencoding()))


def verify_cb(conn, cert, errnum, depth, ok):
    """
    Verification callback which returns its ``ok`` argument unchanged.
    """
    return ok


def socket_pair():
    """
    Establish and return a pair of network sockets connected to each other.

    :return: A ``(server, client)`` tuple of connected, non-blocking sockets.
    """
    # Connect a pair of sockets
    port = socket_any_family()
    try:
        port.bind(("", 0))
        port.listen(1)
        client = socket(port.family)
        # The connect is issued non-blocking so that it does not have to wait
        # for the accept() call below.
        client.setblocking(False)
        client.connect_ex((loopback_address(port), port.getsockname()[1]))
        client.setblocking(True)
        server = port.accept()[0]
    finally:
        # The listening socket is only needed to accept this one connection;
        # close it so every call does not leak a file descriptor.
        port.close()

    # Let's pass some unencrypted data to make sure our socket connection is
    # fine.  Just one byte, so we don't have to worry about buffers getting
    # filled up or fragmentation.
    server.send(b"x")
    assert client.recv(1024) == b"x"
    client.send(b"y")
    assert server.recv(1024) == b"y"

    # Most of our callers want non-blocking sockets, make it easy for them.
    server.setblocking(False)
    client.setblocking(False)

    return (server, client)


def handshake(client, server):
    """
    Call ``do_handshake`` on both connections until each call has succeeded.

    A connection whose ``do_handshake`` raises ``WantReadError`` is retried on
    the next pass; any other exception propagates to the caller.
    """
    conns = [client, server]
    while conns:
        # Iterate over a snapshot: removing an element from the list that is
        # being iterated makes the loop skip the element that follows it.
        for conn in list(conns):
            try:
                conn.do_handshake()
            except WantReadError:
                pass
            else:
                conns.remove(conn)


def _create_certificate_chain():
    """
    Construct and return a chain of certificates.

        1. A new self-signed certificate authority certificate (cacert)
        2. A new intermediate certificate signed by cacert (icert)
        3. A new server certificate signed by icert (scert)
    """
    caext = X509Extension(b"basicConstraints", False, b"CA:true")
    not_after_date = datetime.date.today() + datetime.timedelta(days=365)
    not_after = not_after_date.strftime("%Y%m%d%H%M%SZ").encode("ascii")

    # Step 1
    cakey = PKey()
    cakey.generate_key(TYPE_RSA, 2048)
    cacert = X509()
    cacert.set_version(2)
    cacert.get_subject().commonName = "Authority Certificate"
    cacert.set_issuer(cacert.get_subject())
    cacert.set_pubkey(cakey)
    cacert.set_notBefore(b"20000101000000Z")
    cacert.set_notAfter(not_after)
    cacert.add_extensions([caext])
    cacert.set_serial_number(0)
    cacert.sign(cakey, "sha256")

    # Step 2
    ikey = PKey()
    ikey.generate_key(TYPE_RSA, 2048)
    icert = X509()
    icert.set_version(2)
    icert.get_subject().commonName = "Intermediate Certificate"
    icert.set_issuer(cacert.get_subject())
    icert.set_pubkey(ikey)
    icert.set_notBefore(b"20000101000000Z")
    icert.set_notAfter(not_after)
    icert.add_extensions([caext])
    icert.set_serial_number(0)
    icert.sign(cakey, "sha256")

    # Step 3
    skey = PKey()
    skey.generate_key(TYPE_RSA, 2048)
    scert = X509()
    scert.set_version(2)
    scert.get_subject().commonName = "Server Certificate"
    scert.set_issuer(icert.get_subject())
    scert.set_pubkey(skey)
    scert.set_notBefore(b"20000101000000Z")
    scert.set_notAfter(not_after)
    scert.add_extensions(
        [X509Extension(b"basicConstraints", True, b"CA:false")]
    )
    scert.set_serial_number(0)
    scert.sign(ikey, "sha256")

    return [(cakey, cacert), (ikey, icert), (skey, scert)]


def loopback_client_factory(socket, version=SSLv23_METHOD):
    """
    Wrap ``socket`` in a `Connection` using a fresh `Context` for ``version``
    and put it into connect (client) state.
    """
    client = Connection(Context(version), socket)
    client.set_connect_state()
    return client


def loopback_server_factory(socket, version=SSLv23_METHOD):
    """
    Wrap ``socket`` in a `Connection` whose `Context` is configured with the
    server key and certificate, and put it into accept (server) state.
    """
    ctx = Context(version)
    ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
    ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
    server = Connection(ctx, socket)
    server.set_accept_state()
    return server


def loopback(server_factory=None, client_factory=None):
    """
    Create a connected socket pair, wrap each end using the given factories
    (defaulting to `loopback_server_factory` and `loopback_client_factory`),
    run the handshake between them and return ``(server, client)`` switched
    to blocking mode.
    """
    if server_factory is None:
        server_factory = loopback_server_factory
    if client_factory is None:
        client_factory = loopback_client_factory

    (server, client) = socket_pair()
    server = server_factory(server)
    client = client_factory(client)

    handshake(client, server)

    server.setblocking(True)
    client.setblocking(True)
    return server, client


def interact_in_memory(client_conn, server_conn):
    """
    Try to read application bytes from each of the two `Connection` objects.
    Copy bytes back and forth between their send/receive buffers for as long
    as there is anything to copy.  When there is nothing more to copy,
    return `None`.  If one of them actually manages to deliver some application
    bytes, return a two-tuple of the connection from which the bytes were read
    and the bytes themselves.
    """
    wrote = True
    while wrote:
        # Loop until neither side has anything to say
        wrote = False

        # Copy stuff from each side's send buffer to the other side's
        # receive buffer.
        for (read, write) in [
            (client_conn, server_conn),
            (server_conn, client_conn),
        ]:

            # Give the side a chance to generate some more bytes, or succeed.
            try:
                data = read.recv(2 ** 16)
            except WantReadError:
                # It didn't succeed, so we'll hope it generated some output.
                pass
            else:
                # It did succeed, so we'll stop now and let the caller deal
                # with it.
                return (read, data)

            while True:
                # Keep copying as long as there's more stuff there.
                try:
                    dirty = read.bio_read(4096)
                except WantReadError:
                    # Okay, nothing more waiting to be sent.  Stop
                    # processing this send buffer.
                    break
                else:
                    # Keep track of the fact that someone generated some
                    # output.
                    wrote = True
                    write.bio_write(dirty)


def handshake_in_memory(client_conn, server_conn):
    """
    Perform the TLS handshake between two `Connection` instances connected to
    each other via memory BIOs.
    """
    client_conn.set_connect_state()
    server_conn.set_accept_state()

    # Start the handshake on each side; WantReadError only means the side is
    # waiting for bytes, which interact_in_memory() below shuttles across.
    for conn in [client_conn, server_conn]:
        try:
            conn.do_handshake()
        except WantReadError:
            pass

    interact_in_memory(client_conn, server_conn)


class TestVersion(object):
    """
    Tests for version information exposed by `OpenSSL.SSL.SSLeay_version` and
    `OpenSSL.SSL.OPENSSL_VERSION_NUMBER`.
    """

    def test_OPENSSL_VERSION_NUMBER(self):
        """
        `OPENSSL_VERSION_NUMBER` is an integer with status in the low byte and
        the patch, fix, minor, and major versions in the nibbles above that.
        """
        assert isinstance(OPENSSL_VERSION_NUMBER, int)

    def test_SSLeay_version(self):
        """
        `SSLeay_version` takes a version type indicator and returns one of a
        number of version strings based on that indicator.
        """
        # Keyed by the returned string, so the final length check also
        # requires the five indicators to yield five distinct strings.
        versions = {}
        for t in [
            SSLEAY_VERSION,
            SSLEAY_CFLAGS,
            SSLEAY_BUILT_ON,
            SSLEAY_PLATFORM,
            SSLEAY_DIR,
        ]:
            version = SSLeay_version(t)
            versions[version] = t
            assert isinstance(version, bytes)
        assert len(versions) == 5


@pytest.fixture
def ca_file(tmpdir):
    """
    Create a valid PEM file with CA certificates and return the path.

    The path is returned as ASCII-encoded ``bytes``.
    """
    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    public_key = key.public_key()

    # Subject and issuer are identical and the certificate is signed with its
    # own key below, i.e. it is self-signed.
    builder = x509.CertificateBuilder()
    builder = builder.subject_name(
        x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")])
    )
    builder = builder.issuer_name(
        x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org")])
    )
    # Valid from one day before now until one day after now.
    one_day = datetime.timedelta(1, 0, 0)
    builder = builder.not_valid_before(datetime.datetime.today() - one_day)
    builder = builder.not_valid_after(datetime.datetime.today() + one_day)
    builder = builder.serial_number(int(uuid.uuid4()))
    builder = builder.public_key(public_key)
    builder = builder.add_extension(
        x509.BasicConstraints(ca=True, path_length=None),
        critical=True,
    )

    certificate = builder.sign(
        private_key=key, algorithm=hashes.SHA256(), backend=default_backend()
    )

    ca_file = tmpdir.join("test.pem")
    ca_file.write_binary(
        certificate.public_bytes(
            encoding=serialization.Encoding.PEM,
        )
    )

    return str(ca_file).encode("ascii")


@pytest.fixture
def context():
    """
    A simple "best TLS you can get" context.
    TLS 1.2+ in any reasonable OpenSSL
    """
    return Context(SSLv23_METHOD)


class TestContext(object):
    """
    Unit tests for `OpenSSL.SSL.Context`.
    """

    @pytest.mark.parametrize(
        "cipher_string",
        [b"hello world:AES128-SHA", u"hello world:AES128-SHA"],
    )
    def test_set_cipher_list(self, context, cipher_string):
        """
        `Context.set_cipher_list` accepts both byte and unicode strings
        for naming the ciphers which connections created with the context
        object will be able to choose from.
        """
        context.set_cipher_list(cipher_string)
        conn = Connection(context, None)

        assert "AES128-SHA" in conn.get_cipher_list()

    def test_set_cipher_list_wrong_type(self, context):
        """
        `Context.set_cipher_list` raises `TypeError` when passed a non-string
        argument.
        """
        with pytest.raises(TypeError):
            context.set_cipher_list(object())

    @flaky.flaky
    def test_set_cipher_list_no_cipher_match(self, context):
        """
        `Context.set_cipher_list` raises `OpenSSL.SSL.Error` with a
        `"no cipher match"` reason string regardless of the TLS
        version.
        """
        with pytest.raises(Error) as excinfo:
            context.set_cipher_list(b"imaginary-cipher")
        # The exception carries a single argument: a list holding one
        # (library, function, reason) tuple.
        assert excinfo.value.args == (
            [
                (
                    "SSL routines",
                    "SSL_CTX_set_cipher_list",
                    "no cipher match",
                )
            ],
        )

    def test_load_client_ca(self, context, ca_file):
        """
        `Context.load_client_ca` works as far as we can tell.
        """
        context.load_client_ca(ca_file)

    def test_load_client_ca_invalid(self, context, tmpdir):
        """
        `Context.load_client_ca` raises an Error if the ca file is invalid.
        """
        # An empty file contains no PEM data at all.
        ca_file = tmpdir.join("test.pem")
        ca_file.write("")

        with pytest.raises(Error) as e:
            context.load_client_ca(str(ca_file).encode("ascii"))

        assert "PEM routines" == e.value.args[0][0][0]

    def test_load_client_ca_unicode(self, context, ca_file):
        """
        Passing the path as unicode raises a warning but works.
        """
        pytest.deprecated_call(context.load_client_ca, ca_file.decode("ascii"))

    def test_set_session_id(self, context):
        """
        `Context.set_session_id` works as far as we can tell.
        """
        context.set_session_id(b"abc")

    def test_set_session_id_fail(self, context):
        """
        `Context.set_session_id` errors are propagated.
        """
        # A 3000 byte session id context is expected to be rejected as too
        # long (see the reason string asserted below).
        with pytest.raises(Error) as e:
            context.set_session_id(b"abc" * 1000)

        assert [
            (
                "SSL routines",
                "SSL_CTX_set_session_id_context",
                "ssl session id context too long",
            )
        ] == e.value.args[0]

    def test_set_session_id_unicode(self, context):
        """
        `Context.set_session_id` raises a warning if a unicode string is
        passed.
        """
        pytest.deprecated_call(context.set_session_id, u"abc")

    def test_method(self):
        """
        `Context` can be instantiated with one of `SSLv2_METHOD`,
        `SSLv3_METHOD`, `SSLv23_METHOD`, `TLSv1_METHOD`, `TLSv1_1_METHOD`,
        or `TLSv1_2_METHOD`.

        A non-integer method raises `TypeError` and an unknown integer raises
        `ValueError`.
        """
        methods = [SSLv23_METHOD, TLSv1_METHOD]
        for meth in methods:
            Context(meth)

        maybe = [SSLv2_METHOD, SSLv3_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD]
        for meth in maybe:
            try:
                Context(meth)
            except (Error, ValueError):
                # Some versions of OpenSSL have SSLv2 / TLSv1.1 / TLSv1.2, some
                # don't.  Difficult to say in advance.
                pass

        with pytest.raises(TypeError):
            Context("")
        with pytest.raises(ValueError):
            Context(10)

    def test_type(self):
        """
        `Context` can be used to create instances of that type.
        """
        assert is_consistent_type(Context, "Context", TLSv1_METHOD)

    def test_use_privatekey(self):
        """
        `Context.use_privatekey` takes an `OpenSSL.crypto.PKey` instance.

        Passing anything else (here a string) raises `TypeError`.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 1024)
        ctx = Context(SSLv23_METHOD)
        ctx.use_privatekey(key)
        with pytest.raises(TypeError):
            ctx.use_privatekey("")

    def test_use_privatekey_file_missing(self, tmpfile):
        """
        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` when passed
        the name of a file which does not exist.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(Error):
            ctx.use_privatekey_file(tmpfile)

    def _use_privatekey_file_test(self, pemfile, filetype):
        """
        Verify that calling ``Context.use_privatekey_file`` with the given
        arguments does not raise an exception.

        A freshly generated RSA key is written to ``pemfile`` in PEM format
        before it is loaded.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 1024)

        with open(pemfile, "wt") as pem:
            pem.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii"))

        ctx = Context(SSLv23_METHOD)
        ctx.use_privatekey_file(pemfile, filetype)

    @pytest.mark.parametrize("filetype", [object(), "", None, 1.0])
    def test_wrong_privatekey_file_wrong_args(self, tmpfile, filetype):
        """
        `Context.use_privatekey_file` raises `TypeError` when called with
        a `filetype` which is not a valid file encoding.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            ctx.use_privatekey_file(tmpfile, filetype)

    def test_use_privatekey_file_bytes(self, tmpfile):
        """
        A private key can be specified from a file by passing a ``bytes``
        instance giving the file name to ``Context.use_privatekey_file``.
        """
        self._use_privatekey_file_test(
            tmpfile + NON_ASCII.encode(getfilesystemencoding()),
            FILETYPE_PEM,
        )

    def test_use_privatekey_file_unicode(self, tmpfile):
        """
        A private key can be specified from a file by passing a ``unicode``
        instance giving the file name to ``Context.use_privatekey_file``.
        """
        self._use_privatekey_file_test(
            tmpfile.decode(getfilesystemencoding()) + NON_ASCII,
            FILETYPE_PEM,
        )

    def test_use_certificate_wrong_args(self):
        """
        `Context.use_certificate` raises `TypeError` when not passed
        exactly one `OpenSSL.crypto.X509` instance as an argument.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            ctx.use_certificate("hello, world")

    def test_use_certificate_uninitialized(self):
        """
        `Context.use_certificate` raises `OpenSSL.SSL.Error` when passed a
        `OpenSSL.crypto.X509` instance which has not been initialized
        (ie, which does not actually have any certificate data).
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(Error):
            ctx.use_certificate(X509())

    def test_use_certificate(self):
        """
        `Context.use_certificate` sets the certificate which will be
        used to identify connections created using the context.
        """
        # TODO
        # Hard to assert anything.  But we could set a privatekey then ask
        # OpenSSL if the cert and key agree using check_privatekey.  Then as
        # long as check_privatekey works right we're good...
        ctx = Context(SSLv23_METHOD)
        ctx.use_certificate(load_certificate(FILETYPE_PEM, root_cert_pem))

    def test_use_certificate_file_wrong_args(self):
        """
        `Context.use_certificate_file` raises `TypeError` if the first
        argument is not a byte string or the second argument is not an integer.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            ctx.use_certificate_file(object(), FILETYPE_PEM)
        with pytest.raises(TypeError):
            ctx.use_certificate_file(b"somefile", object())
        # NOTE(review): this check repeats the first one above verbatim; it
        # may have been meant to cover a different argument combination.
        with pytest.raises(TypeError):
            ctx.use_certificate_file(object(), FILETYPE_PEM)

    def test_use_certificate_file_missing(self, tmpfile):
        """
        `Context.use_certificate_file` raises `OpenSSL.SSL.Error` if passed
        the name of a file which does not exist.
        """
        ctx = Context(SSLv23_METHOD)
        with pytest.raises(Error):
            ctx.use_certificate_file(tmpfile)

    def _use_certificate_file_test(self, certificate_file):
        """
        Verify that calling ``Context.use_certificate_file`` with the given
        filename doesn't raise an exception.

        The root certificate PEM is written to ``certificate_file`` before it
        is loaded.
        """
        # TODO
        # Hard to assert anything.  But we could set a privatekey then ask
        # OpenSSL if the cert and key agree using check_privatekey.  Then as
        # long as check_privatekey works right we're good...
        with open(certificate_file, "wb") as pem_file:
            pem_file.write(root_cert_pem)

        ctx = Context(SSLv23_METHOD)
        ctx.use_certificate_file(certificate_file)

    def test_use_certificate_file_bytes(self, tmpfile):
        """
        `Context.use_certificate_file` sets the certificate (given as a
        `bytes` filename) which will be used to identify connections created
        using the context.
        """
        filename = tmpfile + NON_ASCII.encode(getfilesystemencoding())
        self._use_certificate_file_test(filename)

    def test_use_certificate_file_unicode(self, tmpfile):
        """
        `Context.use_certificate_file` sets the certificate (given as a
        `unicode` filename) which will be used to identify connections created
        using the context.
        """
        filename = tmpfile.decode(getfilesystemencoding()) + NON_ASCII
        self._use_certificate_file_test(filename)

    def test_check_privatekey_valid(self):
        """
        `Context.check_privatekey` returns `None` if the `Context` instance
        has been configured to use a matched key and certificate pair.
        """
        key = load_privatekey(FILETYPE_PEM, client_key_pem)
        cert = load_certificate(FILETYPE_PEM, client_cert_pem)
        context = Context(SSLv23_METHOD)
        context.use_privatekey(key)
        context.use_certificate(cert)
        assert None is context.check_privatekey()

    def test_check_privatekey_invalid(self):
        """
        `Context.check_privatekey` raises `Error` if the `Context` instance
        has been configured to use a key and certificate pair which don't
        relate to each other.
        """
        # Deliberate mismatch: the client key with the server certificate.
        key = load_privatekey(FILETYPE_PEM, client_key_pem)
        cert = load_certificate(FILETYPE_PEM, server_cert_pem)
        context = Context(SSLv23_METHOD)
        context.use_privatekey(key)
        context.use_certificate(cert)
        with pytest.raises(Error):
            context.check_privatekey()

    def test_app_data(self):
        """
        `Context.set_app_data` stores an object for later retrieval
        using `Context.get_app_data`.
        """
        app_data = object()
        context = Context(SSLv23_METHOD)
        context.set_app_data(app_data)
        assert context.get_app_data() is app_data

    def test_set_options_wrong_args(self):
        """
        `Context.set_options` raises `TypeError` if called with
        a non-`int` argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_options(None)

    def test_set_options(self):
        """
        `Context.set_options` returns the new options value.
        """
        context = Context(SSLv23_METHOD)
        options = context.set_options(OP_NO_SSLv2)
        # Only the requested bit is checked; other bits may be set as well.
        assert options & OP_NO_SSLv2 == OP_NO_SSLv2

    def test_set_mode_wrong_args(self):
        """
        `Context.set_mode` raises `TypeError` if called with
        a non-`int` argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_mode(None)

    def test_set_mode(self):
        """
        `Context.set_mode` accepts a mode bitvector and returns the
        newly set mode.
        """
        context = Context(SSLv23_METHOD)
        assert MODE_RELEASE_BUFFERS & context.set_mode(MODE_RELEASE_BUFFERS)

    def test_set_timeout_wrong_args(self):
        """
        `Context.set_timeout` raises `TypeError` if called with
        a non-`int` argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_timeout(None)

    def test_timeout(self):
        """
        `Context.set_timeout` sets the session timeout for all connections
        created using the context object. `Context.get_timeout` retrieves
        this value.
        """
        context = Context(SSLv23_METHOD)
        context.set_timeout(1234)
        assert context.get_timeout() == 1234

    def test_set_verify_depth_wrong_args(self):
        """
        `Context.set_verify_depth` raises `TypeError` if called with a
        non-`int` argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_verify_depth(None)

    def test_verify_depth(self):
        """
        `Context.set_verify_depth` sets the number of certificates in
        a chain to follow before giving up.  The value can be retrieved with
        `Context.get_verify_depth`.
        """
        context = Context(SSLv23_METHOD)
        context.set_verify_depth(11)
        assert context.get_verify_depth() == 11

    def _write_encrypted_pem(self, passphrase, tmpfile):
        """
        Write a new private key out to a new file, encrypted using the given
        passphrase.  Return the path to the new file.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 1024)
        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase)
        with open(tmpfile, "w") as fObj:
            fObj.write(pem.decode("ascii"))
        return tmpfile

    def test_set_passwd_cb_wrong_args(self):
        """
        `Context.set_passwd_cb` raises `TypeError` if called with a
        non-callable first argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_passwd_cb(None)

    def test_set_passwd_cb(self, tmpfile):
        """
        `Context.set_passwd_cb` accepts a callable which will be invoked when
        a private key is loaded from an encrypted PEM.
        """
        passphrase = b"foobar"
        pemFile = self._write_encrypted_pem(passphrase, tmpfile)
        calledWith = []

        def passphraseCallback(maxlen, verify, extra):
            calledWith.append((maxlen, verify, extra))
            return passphrase

        context = Context(SSLv23_METHOD)
        context.set_passwd_cb(passphraseCallback)
        context.use_privatekey_file(pemFile)
        # The callback runs exactly once, with two integers and ``None`` as
        # the extra argument (none was given to set_passwd_cb above).
        assert len(calledWith) == 1
        assert isinstance(calledWith[0][0], int)
        assert isinstance(calledWith[0][1], int)
        assert calledWith[0][2] is None

    def test_passwd_callback_exception(self, tmpfile):
        """
        `Context.use_privatekey_file` propagates any exception raised
        by the passphrase callback.
        """
        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)

        def passphraseCallback(maxlen, verify, extra):
            raise RuntimeError("Sorry, I am a fail.")

        context = Context(SSLv23_METHOD)
        context.set_passwd_cb(passphraseCallback)
        with pytest.raises(RuntimeError):
            context.use_privatekey_file(pemFile)

    def test_passwd_callback_false(self, tmpfile):
        """
        `Context.use_privatekey_file` raises `OpenSSL.SSL.Error` if the
        passphrase callback returns a false value.
        """
        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)

        def passphraseCallback(maxlen, verify, extra):
            return b""

        context = Context(SSLv23_METHOD)
        context.set_passwd_cb(passphraseCallback)
        with pytest.raises(Error):
            context.use_privatekey_file(pemFile)

    def test_passwd_callback_non_string(self, tmpfile):
        """
        `Context.use_privatekey_file` raises `ValueError` if the
        passphrase callback returns a true non-string value.
        """
        pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile)

        def passphraseCallback(maxlen, verify, extra):
            return 10

        context = Context(SSLv23_METHOD)
        context.set_passwd_cb(passphraseCallback)
        # TODO: Surely this is the wrong error?
        with pytest.raises(ValueError):
            context.use_privatekey_file(pemFile)

    def test_passwd_callback_too_long(self, tmpfile):
        """
        If the passphrase callback returns a string longer than the indicated
        maximum length, it is truncated.
        """
        # A priori knowledge!
        passphrase = b"x" * 1024
        pemFile = self._write_encrypted_pem(passphrase, tmpfile)

        def passphraseCallback(maxlen, verify, extra):
            assert maxlen == 1024
            # One byte longer than the maximum asserted above.
            return passphrase + b"y"

        context = Context(SSLv23_METHOD)
        context.set_passwd_cb(passphraseCallback)
        # This shall succeed because the truncated result is the correct
        # passphrase.
        context.use_privatekey_file(pemFile)

    def test_set_info_callback(self):
        """
        `Context.set_info_callback` accepts a callable which will be
        invoked when certain information about an SSL connection is available.
        """
        (server, client) = socket_pair()

        clientSSL = Connection(Context(SSLv23_METHOD), client)
        clientSSL.set_connect_state()

        called = []

        def info(conn, where, ret):
            called.append((conn, where, ret))

        # Only the server-side context gets the info callback installed.
        context = Context(SSLv23_METHOD)
        context.set_info_callback(info)
        context.use_certificate(load_certificate(FILETYPE_PEM, root_cert_pem))
        context.use_privatekey(load_privatekey(FILETYPE_PEM, root_key_pem))

        serverSSL = Connection(context, server)
        serverSSL.set_accept_state()

        handshake(clientSSL, serverSSL)

        # The callback must always be called with a Connection instance as the
        # first argument.  It would probably be better to split this into
        # separate tests for client and server side info callbacks so we could
        # assert it is called with the right Connection instance.  It would
        # also be good to assert *something* about `where` and `ret`.
        notConnections = [
            conn
            for (conn, where, ret) in called
            if not isinstance(conn, Connection)
        ]
        assert (
            [] == notConnections
        ), "Some info callback arguments were not Connection instances."

    @pytest.mark.skipif(
        not getattr(_lib, "Cryptography_HAS_KEYLOG", None),
        reason="SSL_CTX_set_keylog_callback unavailable",
    )
    def test_set_keylog_callback(self):
        """
        `Context.set_keylog_callback` accepts a callable which will be
        invoked when key material is generated or received.
        """
        called = []

        def keylog(conn, line):
            called.append((conn, line))

        server_context = Context(TLSv1_2_METHOD)
        server_context.set_keylog_callback(keylog)
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, root_cert_pem)
        )
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, root_key_pem)
        )

        client_context = Context(SSLv23_METHOD)

        self._handshake_test(server_context, client_context)

        # The callback must have fired at least once, always with a
        # Connection and a line containing b"CLIENT_RANDOM".
        assert called
        assert all(isinstance(conn, Connection) for conn, line in called)
        assert all(b"CLIENT_RANDOM" in line for conn, line in called)

    def _load_verify_locations_test(self, *args):
        """
        Create a client context which will verify the peer certificate and call
        its `load_verify_locations` method with the given arguments.
        Then connect it to a server and ensure that the handshake succeeds.
        """
        (server, client) = socket_pair()

        clientContext = Context(SSLv23_METHOD)
        clientContext.load_verify_locations(*args)
        # Require that the server certificate verify properly or the
        # connection will fail.
        clientContext.set_verify(
            VERIFY_PEER,
            lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
        )

        clientSSL = Connection(clientContext, client)
        clientSSL.set_connect_state()

        serverContext = Context(SSLv23_METHOD)
        serverContext.use_certificate(
            load_certificate(FILETYPE_PEM, root_cert_pem)
        )
        serverContext.use_privatekey(
            load_privatekey(FILETYPE_PEM, root_key_pem)
        )

        serverSSL = Connection(serverContext, server)
        serverSSL.set_accept_state()

        # Without load_verify_locations above, the handshake
        # will fail:
        # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
        #          'certificate verify failed')]
        handshake(clientSSL, serverSSL)

        # The certificate the client received must be the one the server was
        # configured with above.
        cert = clientSSL.get_peer_certificate()
        assert cert.get_subject().CN == "Testing Root CA"

    def _load_verify_cafile(self, cafile):
        """
        Verify that if the path to a file containing a certificate is passed
        to `Context.load_verify_locations` for the ``cafile`` parameter, that
        certificate is used as a trust root for the purposes of verifying
        connections created using that `Context`.

        :param cafile: Path at which the root certificate is written before
            it is handed to `load_verify_locations`.
        """
        with open(cafile, "w") as fObj:
            fObj.write(root_cert_pem.decode("ascii"))

        self._load_verify_locations_test(cafile)

    def test_load_verify_bytes_cafile(self, tmpfile):
        """
        `Context.load_verify_locations` accepts a file name as a `bytes`
        instance and uses the certificates within for verification purposes.
        """
        # A non-ASCII suffix is appended so the name exercises the
        # filesystem encoding.
        cafile = tmpfile + NON_ASCII.encode(getfilesystemencoding())
        self._load_verify_cafile(cafile)

    def test_load_verify_unicode_cafile(self, tmpfile):
        """
        `Context.load_verify_locations` accepts a file name as a `unicode`
        instance and uses the certificates within for verification purposes.
        """
        self._load_verify_cafile(
            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
        )

    def test_load_verify_invalid_file(self, tmpfile):
        """
        `Context.load_verify_locations` raises `Error` when passed a
        non-existent cafile.
        """
        clientContext = Context(SSLv23_METHOD)
        # This test never writes anything to ``tmpfile``; it is only used as
        # a name.
        with pytest.raises(Error):
            clientContext.load_verify_locations(tmpfile)

    def _load_verify_directory_locations_capath(self, capath):
        """
        Verify that if the path to a directory containing certificate files is
        passed to ``Context.load_verify_locations`` for the ``capath``
        parameter, those certificates are used as trust roots for the purposes
        of verifying connections created using that ``Context``.

        :param capath: Directory to create and fill with the root
            certificate before it is handed to `load_verify_locations`.
        """
        makedirs(capath)
        # Hash values computed manually with c_rehash to avoid depending on
        # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
        # from OpenSSL 1.0.0.
        for name in [b"c7adac82.0", b"c3705638.0"]:
            cafile = join_bytes_or_unicode(capath, name)
            with open(cafile, "w") as fObj:
                fObj.write(root_cert_pem.decode("ascii"))

        # The first argument is None, so only the directory is supplied.
        self._load_verify_locations_test(None, capath)

    def test_load_verify_directory_bytes_capath(self, tmpfile):
        """
        `Context.load_verify_locations` accepts a directory name as a `bytes`
        instance and uses the certificates within for verification purposes.
        """
        self._load_verify_directory_locations_capath(
            tmpfile + NON_ASCII.encode(getfilesystemencoding())
        )

    def test_load_verify_directory_unicode_capath(self, tmpfile):
        """
        `Context.load_verify_locations` accepts a directory name as a `unicode`
        instance and uses the certificates within for verification purposes.
        """
        self._load_verify_directory_locations_capath(
            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
        )

    def test_load_verify_locations_wrong_args(self):
        """
        `Context.load_verify_locations` raises `TypeError` if called with
        non-`str` arguments.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.load_verify_locations(object())
        with pytest.raises(TypeError):
            context.load_verify_locations(object(), object())

    @pytest.mark.skipif(
        not platform.startswith("linux"),
        reason="Loading fallback paths is a linux-specific behavior to "
        "accommodate pyca/cryptography manylinux1 wheels",
    )
    def test_fallback_default_verify_paths(self, monkeypatch):
        """
        Test that we load certificates successfully on linux from the fallback
        path. To do this we set the _CRYPTOGRAPHY_MANYLINUX1_CA_FILE and
        _CRYPTOGRAPHY_MANYLINUX1_CA_DIR vars to be equal to whatever the
        current OpenSSL default is and we disable
        SSL_CTX_set_default_verify_paths so that it can't find certs unless
        it loads via fallback.
        """
        context = Context(SSLv23_METHOD)
        # Replace the real OpenSSL call with a stub that only reports
        # success.
        monkeypatch.setattr(
            _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1
        )
        monkeypatch.setattr(
            SSL,
            "_CRYPTOGRAPHY_MANYLINUX1_CA_FILE",
            _ffi.string(_lib.X509_get_default_cert_file()),
        )
        monkeypatch.setattr(
            SSL,
            "_CRYPTOGRAPHY_MANYLINUX1_CA_DIR",
            _ffi.string(_lib.X509_get_default_cert_dir()),
        )
        context.set_default_verify_paths()
        # The store must exist and contain at least one object.
        store = context.get_cert_store()
        sk_obj = _lib.X509_STORE_get0_objects(store._store)
        assert sk_obj != _ffi.NULL
        num = _lib.sk_X509_OBJECT_num(sk_obj)
        assert num != 0

    def test_check_env_vars(self, monkeypatch):
        """
        Test that we return True/False appropriately if the env vars are set.
        """
        context = Context(SSLv23_METHOD)
        dir_var = "CUSTOM_DIR_VAR"
        file_var = "CUSTOM_FILE_VAR"
        assert context._check_env_vars_set(dir_var, file_var) is False
        monkeypatch.setenv(dir_var, "value")
        monkeypatch.setenv(file_var, "value")
        assert context._check_env_vars_set(dir_var, file_var) is True
        # NOTE(review): this repeats the previous assertion verbatim;
        # presumably a copy/paste leftover - confirm before removing.
        assert context._check_env_vars_set(dir_var, file_var) is True

    def test_verify_no_fallback_if_env_vars_set(self, monkeypatch):
        """
        Test that we don't use the fallback path if env vars are set.
        """
        context = Context(SSLv23_METHOD)
        monkeypatch.setattr(
            _lib, "SSL_CTX_set_default_verify_paths", lambda x: 1
        )
        # Ask OpenSSL for the names of its own override environment
        # variables and set both of them.
        dir_env_var = _ffi.string(_lib.X509_get_default_cert_dir_env()).decode(
            "ascii"
        )
        file_env_var = _ffi.string(
            _lib.X509_get_default_cert_file_env()
        ).decode("ascii")
        monkeypatch.setenv(dir_env_var, "value")
        monkeypatch.setenv(file_env_var, "value")
        context.set_default_verify_paths()

        # With the fallback replaced by a raiser, this second call can only
        # succeed if the fallback is never invoked.
        monkeypatch.setattr(
            context, "_fallback_default_verify_paths", raiser(SystemError)
        )
        context.set_default_verify_paths()

    @pytest.mark.skipif(
        platform == "win32",
        reason="set_default_verify_paths appears not to work on Windows. "
        "See LP#404343 and LP#404344.",
    )
    def test_set_default_verify_paths(self):
        """
        `Context.set_default_verify_paths` causes the platform-specific CA
        certificate locations to be used for verification purposes.
        """
        # Testing this requires a server with a certificate signed by one
        # of the CAs in the platform CA location. Getting one of those
        # costs money. Fortunately (or unfortunately, depending on your
        # perspective), it's easy to think of a public server on the
        # internet which has such a certificate. Connecting to the network
        # in a unit test is bad, but it's the only way I can think of to
        # really test this.
-exarkun 1255 context = Context(SSLv23_METHOD) 1256 context.set_default_verify_paths() 1257 context.set_verify( 1258 VERIFY_PEER, 1259 lambda conn, cert, errno, depth, preverify_ok: preverify_ok, 1260 ) 1261 1262 client = socket_any_family() 1263 client.connect(("encrypted.google.com", 443)) 1264 clientSSL = Connection(context, client) 1265 clientSSL.set_connect_state() 1266 clientSSL.set_tlsext_host_name(b"encrypted.google.com") 1267 clientSSL.do_handshake() 1268 clientSSL.send(b"GET / HTTP/1.0\r\n\r\n") 1269 assert clientSSL.recv(1024) 1270 1271 def test_fallback_path_is_not_file_or_dir(self): 1272 """ 1273 Test that when passed empty arrays or paths that do not exist no 1274 errors are raised. 1275 """ 1276 context = Context(SSLv23_METHOD) 1277 context._fallback_default_verify_paths([], []) 1278 context._fallback_default_verify_paths(["/not/a/file"], ["/not/a/dir"]) 1279 1280 def test_add_extra_chain_cert_invalid_cert(self): 1281 """ 1282 `Context.add_extra_chain_cert` raises `TypeError` if called with an 1283 object which is not an instance of `X509`. 1284 """ 1285 context = Context(SSLv23_METHOD) 1286 with pytest.raises(TypeError): 1287 context.add_extra_chain_cert(object()) 1288 1289 def _handshake_test(self, serverContext, clientContext): 1290 """ 1291 Verify that a client and server created with the given contexts can 1292 successfully handshake and communicate. 1293 """ 1294 serverSocket, clientSocket = socket_pair() 1295 1296 server = Connection(serverContext, serverSocket) 1297 server.set_accept_state() 1298 1299 client = Connection(clientContext, clientSocket) 1300 client.set_connect_state() 1301 1302 # Make them talk to each other. 
1303 # interact_in_memory(client, server) 1304 for _ in range(3): 1305 for s in [client, server]: 1306 try: 1307 s.do_handshake() 1308 except WantReadError: 1309 pass 1310 1311 def test_set_verify_callback_connection_argument(self): 1312 """ 1313 The first argument passed to the verify callback is the 1314 `Connection` instance for which verification is taking place. 1315 """ 1316 serverContext = Context(SSLv23_METHOD) 1317 serverContext.use_privatekey( 1318 load_privatekey(FILETYPE_PEM, root_key_pem) 1319 ) 1320 serverContext.use_certificate( 1321 load_certificate(FILETYPE_PEM, root_cert_pem) 1322 ) 1323 serverConnection = Connection(serverContext, None) 1324 1325 class VerifyCallback(object): 1326 def callback(self, connection, *args): 1327 self.connection = connection 1328 return 1 1329 1330 verify = VerifyCallback() 1331 clientContext = Context(SSLv23_METHOD) 1332 clientContext.set_verify(VERIFY_PEER, verify.callback) 1333 clientConnection = Connection(clientContext, None) 1334 clientConnection.set_connect_state() 1335 1336 handshake_in_memory(clientConnection, serverConnection) 1337 1338 assert verify.connection is clientConnection 1339 1340 def test_x509_in_verify_works(self): 1341 """ 1342 We had a bug where the X509 cert instantiated in the callback wrapper 1343 didn't __init__ so it was missing objects needed when calling 1344 get_subject. This test sets up a handshake where we call get_subject 1345 on the cert provided to the verify callback. 
        """
        serverContext = Context(SSLv23_METHOD)
        serverContext.use_privatekey(
            load_privatekey(FILETYPE_PEM, root_key_pem)
        )
        serverContext.use_certificate(
            load_certificate(FILETYPE_PEM, root_cert_pem)
        )
        serverConnection = Connection(serverContext, None)

        def verify_cb_get_subject(conn, cert, errnum, depth, ok):
            # The regression being guarded against: get_subject() must work
            # on the certificate object handed to the callback.
            assert cert.get_subject()
            return 1

        clientContext = Context(SSLv23_METHOD)
        clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject)
        clientConnection = Connection(clientContext, None)
        clientConnection.set_connect_state()

        handshake_in_memory(clientConnection, serverConnection)

    def test_set_verify_callback_exception(self):
        """
        If the verify callback passed to `Context.set_verify` raises an
        exception, verification fails and the exception is propagated to the
        caller of `Connection.do_handshake`.
        """
        serverContext = Context(TLSv1_2_METHOD)
        serverContext.use_privatekey(
            load_privatekey(FILETYPE_PEM, root_key_pem)
        )
        serverContext.use_certificate(
            load_certificate(FILETYPE_PEM, root_cert_pem)
        )

        clientContext = Context(TLSv1_2_METHOD)

        def verify_callback(*args):
            raise Exception("silly verify failure")

        clientContext.set_verify(VERIFY_PEER, verify_callback)

        # _handshake_test only swallows WantReadError, so the callback's own
        # exception is what surfaces here.
        with pytest.raises(Exception) as exc:
            self._handshake_test(serverContext, clientContext)

        assert "silly verify failure" == str(exc.value)

    def test_set_verify_callback_reference(self):
        """
        If the verify callback passed to `Context.set_verify` is set multiple
        times, the pointers to the old call functions should not be dangling
        and trigger a segfault.
        """
        serverContext = Context(TLSv1_2_METHOD)
        serverContext.use_privatekey(
            load_privatekey(FILETYPE_PEM, root_key_pem)
        )
        serverContext.use_certificate(
            load_certificate(FILETYPE_PEM, root_cert_pem)
        )

        clientContext = Context(TLSv1_2_METHOD)

        clients = []

        # Each iteration creates a Connection on the shared client context
        # and then installs a brand-new callback function object, replacing
        # the one that was current when earlier Connections were created.
        for i in range(5):

            def verify_callback(*args):
                return True

            serverSocket, clientSocket = socket_pair()
            client = Connection(clientContext, clientSocket)

            clients.append((serverSocket, client))

            clientContext.set_verify(VERIFY_PEER, verify_callback)

        # Collect garbage before the connections are used, so anything that
        # was only kept alive by a replaced callback is gone by now.
        gc.collect()

        # Make them talk to each other.
        for serverSocket, client in clients:
            server = Connection(serverContext, serverSocket)
            server.set_accept_state()
            client.set_connect_state()

            for _ in range(5):
                for s in [client, server]:
                    try:
                        s.do_handshake()
                    except WantReadError:
                        pass

    @pytest.mark.parametrize("mode", [SSL.VERIFY_PEER, SSL.VERIFY_NONE])
    def test_set_verify_default_callback(self, mode):
        """
        If the verify callback is omitted, the preverify value is used.
        """
        serverContext = Context(TLSv1_2_METHOD)
        serverContext.use_privatekey(
            load_privatekey(FILETYPE_PEM, root_key_pem)
        )
        serverContext.use_certificate(
            load_certificate(FILETYPE_PEM, root_cert_pem)
        )

        # The client context never loads any verify locations, so with
        # VERIFY_PEER the handshake is expected to fail verification.
        clientContext = Context(TLSv1_2_METHOD)
        clientContext.set_verify(mode, None)

        if mode == SSL.VERIFY_PEER:
            with pytest.raises(Exception) as exc:
                self._handshake_test(serverContext, clientContext)
            assert "certificate verify failed" in str(exc.value)
        else:
            self._handshake_test(serverContext, clientContext)

    def test_add_extra_chain_cert(self, tmpdir):
        """
        `Context.add_extra_chain_cert` accepts an `X509`
        instance to add to the certificate chain.

        See `_create_certificate_chain` for the details of the
        certificate chain tested.

        The chain is tested by starting a server with scert and connecting
        to it with a client which trusts cacert and requires verification to
        succeed.
        """
        chain = _create_certificate_chain()
        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain

        # Dump the certificates and keys to files. The CA certificate has to
        # be on disk because that's the only way to load it as a trusted CA
        # in the client context; the other files are written but not read
        # back by this test.
        for cert, name in [
            (cacert, "ca.pem"),
            (icert, "i.pem"),
            (scert, "s.pem"),
        ]:
            with tmpdir.join(name).open("w") as f:
                f.write(dump_certificate(FILETYPE_PEM, cert).decode("ascii"))

        for key, name in [(cakey, "ca.key"), (ikey, "i.key"), (skey, "s.key")]:
            with tmpdir.join(name).open("w") as f:
                f.write(dump_privatekey(FILETYPE_PEM, key).decode("ascii"))

        # Create the server context
        serverContext = Context(SSLv23_METHOD)
        serverContext.use_privatekey(skey)
        serverContext.use_certificate(scert)
        # The client already has cacert, we only need to give them icert.
        serverContext.add_extra_chain_cert(icert)

        # Create the client
        clientContext = Context(SSLv23_METHOD)
        clientContext.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb
        )
        clientContext.load_verify_locations(str(tmpdir.join("ca.pem")))

        # Try it out.
        self._handshake_test(serverContext, clientContext)

    def _use_certificate_chain_file_test(self, certdir):
        """
        Verify that `Context.use_certificate_chain_file` reads a
        certificate chain from a specified file.

        The chain is tested by starting a server with scert and connecting to
        it with a client which trusts cacert and requires verification to
        succeed.
        """
        chain = _create_certificate_chain()
        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain

        makedirs(certdir)

        chainFile = join_bytes_or_unicode(certdir, "chain.pem")
        caFile = join_bytes_or_unicode(certdir, "ca.pem")

        # Write out the chain file.
        with open(chainFile, "wb") as fObj:
            # Most specific to least general: server certificate first, then
            # the intermediate, then the CA.
            fObj.write(dump_certificate(FILETYPE_PEM, scert))
            fObj.write(dump_certificate(FILETYPE_PEM, icert))
            fObj.write(dump_certificate(FILETYPE_PEM, cacert))

        # The CA certificate alone goes into a second file, which the client
        # loads as its trust root.
        with open(caFile, "w") as fObj:
            fObj.write(dump_certificate(FILETYPE_PEM, cacert).decode("ascii"))

        serverContext = Context(SSLv23_METHOD)
        serverContext.use_certificate_chain_file(chainFile)
        serverContext.use_privatekey(skey)

        clientContext = Context(SSLv23_METHOD)
        clientContext.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb
        )
        clientContext.load_verify_locations(caFile)

        self._handshake_test(serverContext, clientContext)

    def test_use_certificate_chain_file_bytes(self, tmpfile):
        """
        ``Context.use_certificate_chain_file`` accepts the name of a file (as
        an instance of ``bytes``) to specify additional certificates to use to
        construct and verify a trust chain.
        """
        self._use_certificate_chain_file_test(
            tmpfile + NON_ASCII.encode(getfilesystemencoding())
        )

    def test_use_certificate_chain_file_unicode(self, tmpfile):
        """
        ``Context.use_certificate_chain_file`` accepts the name of a file (as
        an instance of ``unicode``) to specify additional certificates to use
        to construct and verify a trust chain.
        """
        self._use_certificate_chain_file_test(
            tmpfile.decode(getfilesystemencoding()) + NON_ASCII
        )

    def test_use_certificate_chain_file_wrong_args(self):
        """
        `Context.use_certificate_chain_file` raises `TypeError` if its single
        argument is not a file name (here, a plain ``object()``).
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.use_certificate_chain_file(object())

    def test_use_certificate_chain_file_missing_file(self, tmpfile):
        """
        `Context.use_certificate_chain_file` raises `OpenSSL.SSL.Error` when
        passed a bad chain file name (for example, the name of a file which
        does not exist).
        """
        context = Context(SSLv23_METHOD)
        # This test never writes anything to ``tmpfile``; it is only used as
        # a name.
        with pytest.raises(Error):
            context.use_certificate_chain_file(tmpfile)

    def test_set_verify_mode(self):
        """
        `Context.get_verify_mode` returns the verify mode flags previously
        passed to `Context.set_verify`.
        """
        context = Context(SSLv23_METHOD)
        # A fresh context reports mode 0 before set_verify is called.
        assert context.get_verify_mode() == 0
        context.set_verify(VERIFY_PEER | VERIFY_CLIENT_ONCE)
        assert context.get_verify_mode() == (VERIFY_PEER | VERIFY_CLIENT_ONCE)

    @pytest.mark.parametrize("mode", [None, 1.0, object(), "mode"])
    def test_set_verify_wrong_mode_arg(self, mode):
        """
        `Context.set_verify` raises `TypeError` if the first argument is
        not an integer.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_verify(mode=mode)

    @pytest.mark.parametrize("callback", [1.0, "mode", ("foo", "bar")])
    def test_set_verify_wrong_callable_arg(self, callback):
        """
        `Context.set_verify` raises `TypeError` if the second argument
        is not callable.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_verify(mode=VERIFY_PEER, callback=callback)

    def test_load_tmp_dh_wrong_args(self):
        """
        `Context.load_tmp_dh` raises `TypeError` if called with a
        non-`str` argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.load_tmp_dh(object())

    def test_load_tmp_dh_missing_file(self):
        """
        `Context.load_tmp_dh` raises `OpenSSL.SSL.Error` if the
        specified file does not exist.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(Error):
            context.load_tmp_dh(b"hello")

    def _load_tmp_dh_test(self, dhfilename):
        """
        Verify that calling ``Context.load_tmp_dh`` with the given filename
        does not raise an exception.

        :param dhfilename: Path at which the module-level ``dhparam`` PEM
            text is written before it is loaded.
        """
        context = Context(SSLv23_METHOD)
        with open(dhfilename, "w") as dhfile:
            dhfile.write(dhparam)

        context.load_tmp_dh(dhfilename)

    def test_load_tmp_dh_bytes(self, tmpfile):
        """
        `Context.load_tmp_dh` loads Diffie-Hellman parameters from the
        specified file (given as ``bytes``).
        """
        self._load_tmp_dh_test(
            tmpfile + NON_ASCII.encode(getfilesystemencoding()),
        )

    def test_load_tmp_dh_unicode(self, tmpfile):
        """
        `Context.load_tmp_dh` loads Diffie-Hellman parameters from the
        specified file (given as ``unicode``).
        """
        self._load_tmp_dh_test(
            tmpfile.decode(getfilesystemencoding()) + NON_ASCII,
        )

    def test_set_tmp_ecdh(self):
        """
        `Context.set_tmp_ecdh` sets the elliptic curve for Diffie-Hellman to
        the specified curve.
        """
        context = Context(SSLv23_METHOD)
        for curve in get_elliptic_curves():
            if curve.name.startswith(u"Oakley-"):
                # Setting Oakley-EC2N-4 and Oakley-EC2N-3 adds
                # ('bignum routines', 'BN_mod_inverse', 'no inverse') to the
                # error queue on OpenSSL 1.0.2.
                continue
            # The only easily "assertable" thing is that it does not raise an
            # exception.
            context.set_tmp_ecdh(curve)

    def test_set_session_cache_mode_wrong_args(self):
        """
        `Context.set_session_cache_mode` raises `TypeError` if called with
        a non-integer argument.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_session_cache_mode(object())

    def test_session_cache_mode(self):
        """
        `Context.set_session_cache_mode` specifies how sessions are cached.
        The setting can be retrieved via `Context.get_session_cache_mode`.
        """
        context = Context(SSLv23_METHOD)
        context.set_session_cache_mode(SESS_CACHE_OFF)
        # The setter returns the mode that was in effect before the call.
        off = context.set_session_cache_mode(SESS_CACHE_BOTH)
        assert SESS_CACHE_OFF == off
        assert SESS_CACHE_BOTH == context.get_session_cache_mode()

    def test_get_cert_store(self):
        """
        `Context.get_cert_store` returns a `X509Store` instance.
        """
        context = Context(SSLv23_METHOD)
        store = context.get_cert_store()
        assert isinstance(store, X509Store)

    def test_set_tlsext_use_srtp_not_bytes(self):
        """
        `Context.set_tlsext_use_srtp` enables negotiating SRTP keying material.

        It raises a TypeError if the list of profiles is not a byte string.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(TypeError):
            context.set_tlsext_use_srtp(text_type("SRTP_AES128_CM_SHA1_80"))

    def test_set_tlsext_use_srtp_invalid_profile(self):
        """
        `Context.set_tlsext_use_srtp` enables negotiating SRTP keying material.

        It raises an Error if the call to OpenSSL fails.
        """
        context = Context(SSLv23_METHOD)
        with pytest.raises(Error):
            context.set_tlsext_use_srtp(b"SRTP_BOGUS")

    def test_set_tlsext_use_srtp_valid(self):
        """
        `Context.set_tlsext_use_srtp` enables negotiating SRTP keying material.

        It does not return anything.
        """
        context = Context(SSLv23_METHOD)
        assert context.set_tlsext_use_srtp(b"SRTP_AES128_CM_SHA1_80") is None


class TestServerNameCallback(object):
    """
    Tests for `Context.set_tlsext_servername_callback` and its
    interaction with `Connection`.
    """

    def test_old_callback_forgotten(self):
        """
        If `Context.set_tlsext_servername_callback` is used to specify
        a new callback, the one it replaces is dereferenced.
        """

        def callback(connection):  # pragma: no cover
            pass

        def replacement(connection):  # pragma: no cover
            pass

        context = Context(SSLv23_METHOD)
        context.set_tlsext_servername_callback(callback)

        # Keep only a weak reference, so the context is the last thing that
        # could be holding the first callback alive.
        tracker = ref(callback)
        del callback

        context.set_tlsext_servername_callback(replacement)

        # One run of the garbage collector happens to work on CPython. PyPy
        # doesn't collect the underlying object until a second run for whatever
        # reason. That's fine, it still demonstrates our code has properly
        # dropped the reference.
        collect()
        collect()

        callback = tracker()
        if callback is not None:
            # NOTE(review): a single referrer is tolerated - presumably the
            # local just bound above - confirm.
            referrers = get_referrers(callback)
            if len(referrers) > 1:  # pragma: nocover
                pytest.fail("Some references remain: %r" % (referrers,))

    def test_no_servername(self):
        """
        When a client specifies no server name, the callback passed to
        `Context.set_tlsext_servername_callback` is invoked and the
        result of `Connection.get_servername` is `None`.
        """
        args = []

        def servername(conn):
            args.append((conn, conn.get_servername()))

        context = Context(SSLv23_METHOD)
        context.set_tlsext_servername_callback(servername)

        # Lose our reference to it. The Context is responsible for keeping it
        # alive now.
        del servername
        collect()

        # Necessary to actually accept the connection
        context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
        context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        # Do a little connection to trigger the logic
        server = Connection(context, None)
        server.set_accept_state()

        client = Connection(Context(SSLv23_METHOD), None)
        client.set_connect_state()

        interact_in_memory(server, client)

        # Called exactly once, with the server connection and no name.
        assert args == [(server, None)]

    def test_servername(self):
        """
        When a client specifies a server name in its hello message, the
        callback passed to `Context.set_tlsext_servername_callback` is
        invoked and the result of `Connection.get_servername` is that
        server name.
        """
        args = []

        def servername(conn):
            args.append((conn, conn.get_servername()))

        context = Context(SSLv23_METHOD)
        context.set_tlsext_servername_callback(servername)

        # Necessary to actually accept the connection
        context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
        context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        # Do a little connection to trigger the logic
        server = Connection(context, None)
        server.set_accept_state()

        client = Connection(Context(SSLv23_METHOD), None)
        client.set_connect_state()
        client.set_tlsext_host_name(b"foo1.example.com")

        interact_in_memory(server, client)

        # Called exactly once, with the server connection and the name the
        # client set above.
        assert args == [(server, b"foo1.example.com")]


class TestApplicationLayerProtoNegotiation(object):
    """
    Tests for ALPN in PyOpenSSL.
    """

    def test_alpn_success(self):
        """
        Clients and servers that agree on the negotiated ALPN protocol can
        correctly establish a connection, and the agreed protocol is reported
        by the connections.
        """
        select_args = []

        def select(conn, options):
            # Record what the server was offered, then pick the second
            # protocol in the client's list.
            select_args.append((conn, options))
            return b"spdy/2"

        client_context = Context(SSLv23_METHOD)
        client_context.set_alpn_protos([b"http/1.1", b"spdy/2"])

        server_context = Context(SSLv23_METHOD)
        server_context.set_alpn_select_callback(select)

        # Necessary to actually accept the connection
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem)
        )
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        # Do a little connection to trigger the logic
        server = Connection(server_context, None)
        server.set_accept_state()

        client = Connection(client_context, None)
        client.set_connect_state()

        interact_in_memory(server, client)

        # The callback ran once, on the server connection, with the client's
        # offer in the order the client listed it.
        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

        assert server.get_alpn_proto_negotiated() == b"spdy/2"
        assert client.get_alpn_proto_negotiated() == b"spdy/2"

    def test_alpn_set_on_connection(self):
        """
        The same as test_alpn_success, but setting the ALPN protocols on
        the connection rather than the context.
        """
        select_args = []

        def select(conn, options):
            select_args.append((conn, options))
            return b"spdy/2"

        # Setup the client context but don't set any ALPN protocols.
        client_context = Context(SSLv23_METHOD)

        server_context = Context(SSLv23_METHOD)
        server_context.set_alpn_select_callback(select)

        # Necessary to actually accept the connection
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem)
        )
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        # Do a little connection to trigger the logic
        server = Connection(server_context, None)
        server.set_accept_state()

        # Set the ALPN protocols on the client connection.
        client = Connection(client_context, None)
        client.set_alpn_protos([b"http/1.1", b"spdy/2"])
        client.set_connect_state()

        interact_in_memory(server, client)

        # Same expectations as when the protocols are set on the context.
        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

        assert server.get_alpn_proto_negotiated() == b"spdy/2"
        assert client.get_alpn_proto_negotiated() == b"spdy/2"

    def test_alpn_server_fail(self):
        """
        When clients and servers cannot agree on what protocol to use next
        the TLS connection does not get established.
        """
        select_args = []

        def select(conn, options):
            select_args.append((conn, options))
            return b""

        client_context = Context(SSLv23_METHOD)
        client_context.set_alpn_protos([b"http/1.1", b"spdy/2"])

        server_context = Context(SSLv23_METHOD)
        server_context.set_alpn_select_callback(select)

        # Necessary to actually accept the connection
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem)
        )
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        # Do a little connection to trigger the logic
        server = Connection(server_context, None)
        server.set_accept_state()

        client = Connection(client_context, None)
        client.set_connect_state()

        # The server's select callback returns an empty protocol, so the
        # connection will fail.
        with pytest.raises(Error):
            interact_in_memory(server, client)

        # The callback was still consulted once before the failure.
        assert select_args == [(server, [b"http/1.1", b"spdy/2"])]

    def test_alpn_no_server_overlap(self):
        """
        A server can allow a TLS handshake to complete without
        agreeing to an application protocol by returning
        ``NO_OVERLAPPING_PROTOCOLS``.
        """
        refusal_args = []

        def refusal(conn, options):
            refusal_args.append((conn, options))
            return NO_OVERLAPPING_PROTOCOLS

        client_context = Context(SSLv23_METHOD)
        client_context.set_alpn_protos([b"http/1.1", b"spdy/2"])

        server_context = Context(SSLv23_METHOD)
        server_context.set_alpn_select_callback(refusal)

        # Necessary to actually accept the connection
        server_context.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem)
        )
        server_context.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )

        # Do a little connection to trigger the logic
        server = Connection(server_context, None)
        server.set_accept_state()

        client = Connection(client_context, None)
        client.set_connect_state()

        # Do the dance.
        interact_in_memory(server, client)

        assert refusal_args == [(server, [b"http/1.1", b"spdy/2"])]

        # The handshake completed, but the client sees no negotiated
        # protocol.
        assert client.get_alpn_proto_negotiated() == b""

    def test_alpn_select_cb_returns_invalid_value(self):
        """
        If the ALPN selection callback returns anything other than
        a bytestring or ``NO_OVERLAPPING_PROTOCOLS``, a
        :py:exc:`TypeError` is raised.
2015 """ 2016 invalid_cb_args = [] 2017 2018 def invalid_cb(conn, options): 2019 invalid_cb_args.append((conn, options)) 2020 return u"can't return unicode" 2021 2022 client_context = Context(SSLv23_METHOD) 2023 client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) 2024 2025 server_context = Context(SSLv23_METHOD) 2026 server_context.set_alpn_select_callback(invalid_cb) 2027 2028 # Necessary to actually accept the connection 2029 server_context.use_privatekey( 2030 load_privatekey(FILETYPE_PEM, server_key_pem) 2031 ) 2032 server_context.use_certificate( 2033 load_certificate(FILETYPE_PEM, server_cert_pem) 2034 ) 2035 2036 # Do a little connection to trigger the logic 2037 server = Connection(server_context, None) 2038 server.set_accept_state() 2039 2040 client = Connection(client_context, None) 2041 client.set_connect_state() 2042 2043 # Do the dance. 2044 with pytest.raises(TypeError): 2045 interact_in_memory(server, client) 2046 2047 assert invalid_cb_args == [(server, [b"http/1.1", b"spdy/2"])] 2048 2049 assert client.get_alpn_proto_negotiated() == b"" 2050 2051 def test_alpn_no_server(self): 2052 """ 2053 When clients and servers cannot agree on what protocol to use next 2054 because the server doesn't offer ALPN, no protocol is negotiated. 2055 """ 2056 client_context = Context(SSLv23_METHOD) 2057 client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) 2058 2059 server_context = Context(SSLv23_METHOD) 2060 2061 # Necessary to actually accept the connection 2062 server_context.use_privatekey( 2063 load_privatekey(FILETYPE_PEM, server_key_pem) 2064 ) 2065 server_context.use_certificate( 2066 load_certificate(FILETYPE_PEM, server_cert_pem) 2067 ) 2068 2069 # Do a little connection to trigger the logic 2070 server = Connection(server_context, None) 2071 server.set_accept_state() 2072 2073 client = Connection(client_context, None) 2074 client.set_connect_state() 2075 2076 # Do the dance. 
2077 interact_in_memory(server, client) 2078 2079 assert client.get_alpn_proto_negotiated() == b"" 2080 2081 def test_alpn_callback_exception(self): 2082 """ 2083 We can handle exceptions in the ALPN select callback. 2084 """ 2085 select_args = [] 2086 2087 def select(conn, options): 2088 select_args.append((conn, options)) 2089 raise TypeError() 2090 2091 client_context = Context(SSLv23_METHOD) 2092 client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) 2093 2094 server_context = Context(SSLv23_METHOD) 2095 server_context.set_alpn_select_callback(select) 2096 2097 # Necessary to actually accept the connection 2098 server_context.use_privatekey( 2099 load_privatekey(FILETYPE_PEM, server_key_pem) 2100 ) 2101 server_context.use_certificate( 2102 load_certificate(FILETYPE_PEM, server_cert_pem) 2103 ) 2104 2105 # Do a little connection to trigger the logic 2106 server = Connection(server_context, None) 2107 server.set_accept_state() 2108 2109 client = Connection(client_context, None) 2110 client.set_connect_state() 2111 2112 with pytest.raises(TypeError): 2113 interact_in_memory(server, client) 2114 assert select_args == [(server, [b"http/1.1", b"spdy/2"])] 2115 2116 2117class TestSession(object): 2118 """ 2119 Unit tests for :py:obj:`OpenSSL.SSL.Session`. 2120 """ 2121 2122 def test_construction(self): 2123 """ 2124 :py:class:`Session` can be constructed with no arguments, creating 2125 a new instance of that type. 2126 """ 2127 new_session = Session() 2128 assert isinstance(new_session, Session) 2129 2130 2131class TestConnection(object): 2132 """ 2133 Unit tests for `OpenSSL.SSL.Connection`. 
2134 """ 2135 2136 # XXX get_peer_certificate -> None 2137 # XXX sock_shutdown 2138 # XXX master_key -> TypeError 2139 # XXX server_random -> TypeError 2140 # XXX connect -> TypeError 2141 # XXX connect_ex -> TypeError 2142 # XXX set_connect_state -> TypeError 2143 # XXX set_accept_state -> TypeError 2144 # XXX do_handshake -> TypeError 2145 # XXX bio_read -> TypeError 2146 # XXX recv -> TypeError 2147 # XXX send -> TypeError 2148 # XXX bio_write -> TypeError 2149 2150 def test_type(self): 2151 """ 2152 `Connection` can be used to create instances of that type. 2153 """ 2154 ctx = Context(SSLv23_METHOD) 2155 assert is_consistent_type(Connection, "Connection", ctx, None) 2156 2157 @pytest.mark.parametrize("bad_context", [object(), "context", None, 1]) 2158 def test_wrong_args(self, bad_context): 2159 """ 2160 `Connection.__init__` raises `TypeError` if called with a non-`Context` 2161 instance argument. 2162 """ 2163 with pytest.raises(TypeError): 2164 Connection(bad_context) 2165 2166 @pytest.mark.parametrize("bad_bio", [object(), None, 1, [1, 2, 3]]) 2167 def test_bio_write_wrong_args(self, bad_bio): 2168 """ 2169 `Connection.bio_write` raises `TypeError` if called with a non-bytes 2170 (or text) argument. 2171 """ 2172 context = Context(SSLv23_METHOD) 2173 connection = Connection(context, None) 2174 with pytest.raises(TypeError): 2175 connection.bio_write(bad_bio) 2176 2177 def test_bio_write(self): 2178 """ 2179 `Connection.bio_write` does not raise if called with bytes or 2180 bytearray, warns if called with text. 2181 """ 2182 context = Context(SSLv23_METHOD) 2183 connection = Connection(context, None) 2184 connection.bio_write(b"xy") 2185 connection.bio_write(bytearray(b"za")) 2186 with pytest.warns(DeprecationWarning): 2187 connection.bio_write(u"deprecated") 2188 2189 def test_get_context(self): 2190 """ 2191 `Connection.get_context` returns the `Context` instance used to 2192 construct the `Connection` instance. 
2193 """ 2194 context = Context(SSLv23_METHOD) 2195 connection = Connection(context, None) 2196 assert connection.get_context() is context 2197 2198 def test_set_context_wrong_args(self): 2199 """ 2200 `Connection.set_context` raises `TypeError` if called with a 2201 non-`Context` instance argument. 2202 """ 2203 ctx = Context(SSLv23_METHOD) 2204 connection = Connection(ctx, None) 2205 with pytest.raises(TypeError): 2206 connection.set_context(object()) 2207 with pytest.raises(TypeError): 2208 connection.set_context("hello") 2209 with pytest.raises(TypeError): 2210 connection.set_context(1) 2211 assert ctx is connection.get_context() 2212 2213 def test_set_context(self): 2214 """ 2215 `Connection.set_context` specifies a new `Context` instance to be 2216 used for the connection. 2217 """ 2218 original = Context(SSLv23_METHOD) 2219 replacement = Context(SSLv23_METHOD) 2220 connection = Connection(original, None) 2221 connection.set_context(replacement) 2222 assert replacement is connection.get_context() 2223 # Lose our references to the contexts, just in case the Connection 2224 # isn't properly managing its own contributions to their reference 2225 # counts. 2226 del original, replacement 2227 collect() 2228 2229 def test_set_tlsext_host_name_wrong_args(self): 2230 """ 2231 If `Connection.set_tlsext_host_name` is called with a non-byte string 2232 argument or a byte string with an embedded NUL, `TypeError` is raised. 2233 """ 2234 conn = Connection(Context(SSLv23_METHOD), None) 2235 with pytest.raises(TypeError): 2236 conn.set_tlsext_host_name(object()) 2237 with pytest.raises(TypeError): 2238 conn.set_tlsext_host_name(b"with\0null") 2239 2240 if not PY2: 2241 # On Python 3.x, don't accidentally implicitly convert from text. 2242 with pytest.raises(TypeError): 2243 conn.set_tlsext_host_name(b"example.com".decode("ascii")) 2244 2245 def test_pending(self): 2246 """ 2247 `Connection.pending` returns the number of bytes available for 2248 immediate read. 
2249 """ 2250 connection = Connection(Context(SSLv23_METHOD), None) 2251 assert connection.pending() == 0 2252 2253 def test_peek(self): 2254 """ 2255 `Connection.recv` peeks into the connection if `socket.MSG_PEEK` is 2256 passed. 2257 """ 2258 server, client = loopback() 2259 server.send(b"xy") 2260 assert client.recv(2, MSG_PEEK) == b"xy" 2261 assert client.recv(2, MSG_PEEK) == b"xy" 2262 assert client.recv(2) == b"xy" 2263 2264 def test_connect_wrong_args(self): 2265 """ 2266 `Connection.connect` raises `TypeError` if called with a non-address 2267 argument. 2268 """ 2269 connection = Connection(Context(SSLv23_METHOD), socket_any_family()) 2270 with pytest.raises(TypeError): 2271 connection.connect(None) 2272 2273 def test_connect_refused(self): 2274 """ 2275 `Connection.connect` raises `socket.error` if the underlying socket 2276 connect method raises it. 2277 """ 2278 client = socket_any_family() 2279 context = Context(SSLv23_METHOD) 2280 clientSSL = Connection(context, client) 2281 # pytest.raises here doesn't work because of a bug in py.test on Python 2282 # 2.6: https://github.com/pytest-dev/pytest/issues/988 2283 try: 2284 clientSSL.connect((loopback_address(client), 1)) 2285 except error as e: 2286 exc = e 2287 assert exc.args[0] == ECONNREFUSED 2288 2289 def test_connect(self): 2290 """ 2291 `Connection.connect` establishes a connection to the specified address. 2292 """ 2293 port = socket_any_family() 2294 port.bind(("", 0)) 2295 port.listen(3) 2296 2297 clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family)) 2298 clientSSL.connect((loopback_address(port), port.getsockname()[1])) 2299 # XXX An assertion? Or something? 2300 2301 @pytest.mark.skipif( 2302 platform == "darwin", 2303 reason="connect_ex sometimes causes a kernel panic on OS X 10.6.4", 2304 ) 2305 def test_connect_ex(self): 2306 """ 2307 If there is a connection error, `Connection.connect_ex` returns the 2308 errno instead of raising an exception. 
2309 """ 2310 port = socket_any_family() 2311 port.bind(("", 0)) 2312 port.listen(3) 2313 2314 clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family)) 2315 clientSSL.setblocking(False) 2316 result = clientSSL.connect_ex(port.getsockname()) 2317 expected = (EINPROGRESS, EWOULDBLOCK) 2318 assert result in expected 2319 2320 def test_accept(self): 2321 """ 2322 `Connection.accept` accepts a pending connection attempt and returns a 2323 tuple of a new `Connection` (the accepted client) and the address the 2324 connection originated from. 2325 """ 2326 ctx = Context(SSLv23_METHOD) 2327 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 2328 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 2329 port = socket_any_family() 2330 portSSL = Connection(ctx, port) 2331 portSSL.bind(("", 0)) 2332 portSSL.listen(3) 2333 2334 clientSSL = Connection(Context(SSLv23_METHOD), socket(port.family)) 2335 2336 # Calling portSSL.getsockname() here to get the server IP address 2337 # sounds great, but frequently fails on Windows. 2338 clientSSL.connect((loopback_address(port), portSSL.getsockname()[1])) 2339 2340 serverSSL, address = portSSL.accept() 2341 2342 assert isinstance(serverSSL, Connection) 2343 assert serverSSL.get_context() is ctx 2344 assert address == clientSSL.getsockname() 2345 2346 def test_shutdown_wrong_args(self): 2347 """ 2348 `Connection.set_shutdown` raises `TypeError` if called with arguments 2349 other than integers. 2350 """ 2351 connection = Connection(Context(SSLv23_METHOD), None) 2352 with pytest.raises(TypeError): 2353 connection.set_shutdown(None) 2354 2355 def test_shutdown(self): 2356 """ 2357 `Connection.shutdown` performs an SSL-level connection shutdown. 
2358 """ 2359 server, client = loopback() 2360 assert not server.shutdown() 2361 assert server.get_shutdown() == SENT_SHUTDOWN 2362 with pytest.raises(ZeroReturnError): 2363 client.recv(1024) 2364 assert client.get_shutdown() == RECEIVED_SHUTDOWN 2365 client.shutdown() 2366 assert client.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN) 2367 with pytest.raises(ZeroReturnError): 2368 server.recv(1024) 2369 assert server.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN) 2370 2371 def test_shutdown_closed(self): 2372 """ 2373 If the underlying socket is closed, `Connection.shutdown` propagates 2374 the write error from the low level write call. 2375 """ 2376 server, client = loopback() 2377 server.sock_shutdown(2) 2378 with pytest.raises(SysCallError) as exc: 2379 server.shutdown() 2380 if platform == "win32": 2381 assert exc.value.args[0] == ESHUTDOWN 2382 else: 2383 assert exc.value.args[0] == EPIPE 2384 2385 def test_shutdown_truncated(self): 2386 """ 2387 If the underlying connection is truncated, `Connection.shutdown` 2388 raises an `Error`. 2389 """ 2390 server_ctx = Context(SSLv23_METHOD) 2391 client_ctx = Context(SSLv23_METHOD) 2392 server_ctx.use_privatekey( 2393 load_privatekey(FILETYPE_PEM, server_key_pem) 2394 ) 2395 server_ctx.use_certificate( 2396 load_certificate(FILETYPE_PEM, server_cert_pem) 2397 ) 2398 server = Connection(server_ctx, None) 2399 client = Connection(client_ctx, None) 2400 handshake_in_memory(client, server) 2401 assert not server.shutdown() 2402 with pytest.raises(WantReadError): 2403 server.shutdown() 2404 server.bio_shutdown() 2405 with pytest.raises(Error): 2406 server.shutdown() 2407 2408 def test_set_shutdown(self): 2409 """ 2410 `Connection.set_shutdown` sets the state of the SSL connection 2411 shutdown process. 
2412 """ 2413 connection = Connection(Context(SSLv23_METHOD), socket_any_family()) 2414 connection.set_shutdown(RECEIVED_SHUTDOWN) 2415 assert connection.get_shutdown() == RECEIVED_SHUTDOWN 2416 2417 def test_state_string(self): 2418 """ 2419 `Connection.state_string` verbosely describes the current state of 2420 the `Connection`. 2421 """ 2422 server, client = socket_pair() 2423 server = loopback_server_factory(server) 2424 client = loopback_client_factory(client) 2425 2426 assert server.get_state_string() in [ 2427 b"before/accept initialization", 2428 b"before SSL initialization", 2429 ] 2430 assert client.get_state_string() in [ 2431 b"before/connect initialization", 2432 b"before SSL initialization", 2433 ] 2434 2435 def test_app_data(self): 2436 """ 2437 Any object can be set as app data by passing it to 2438 `Connection.set_app_data` and later retrieved with 2439 `Connection.get_app_data`. 2440 """ 2441 conn = Connection(Context(SSLv23_METHOD), None) 2442 assert None is conn.get_app_data() 2443 app_data = object() 2444 conn.set_app_data(app_data) 2445 assert conn.get_app_data() is app_data 2446 2447 def test_makefile(self): 2448 """ 2449 `Connection.makefile` is not implemented and calling that 2450 method raises `NotImplementedError`. 2451 """ 2452 conn = Connection(Context(SSLv23_METHOD), None) 2453 with pytest.raises(NotImplementedError): 2454 conn.makefile() 2455 2456 def test_get_certificate(self): 2457 """ 2458 `Connection.get_certificate` returns the local certificate. 2459 """ 2460 chain = _create_certificate_chain() 2461 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 2462 2463 context = Context(SSLv23_METHOD) 2464 context.use_certificate(scert) 2465 client = Connection(context, None) 2466 cert = client.get_certificate() 2467 assert cert is not None 2468 assert "Server Certificate" == cert.get_subject().CN 2469 2470 def test_get_certificate_none(self): 2471 """ 2472 `Connection.get_certificate` returns the local certificate. 
2473 2474 If there is no certificate, it returns None. 2475 """ 2476 context = Context(SSLv23_METHOD) 2477 client = Connection(context, None) 2478 cert = client.get_certificate() 2479 assert cert is None 2480 2481 def test_get_peer_cert_chain(self): 2482 """ 2483 `Connection.get_peer_cert_chain` returns a list of certificates 2484 which the connected server returned for the certification verification. 2485 """ 2486 chain = _create_certificate_chain() 2487 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 2488 2489 serverContext = Context(SSLv23_METHOD) 2490 serverContext.use_privatekey(skey) 2491 serverContext.use_certificate(scert) 2492 serverContext.add_extra_chain_cert(icert) 2493 serverContext.add_extra_chain_cert(cacert) 2494 server = Connection(serverContext, None) 2495 server.set_accept_state() 2496 2497 # Create the client 2498 clientContext = Context(SSLv23_METHOD) 2499 clientContext.set_verify(VERIFY_NONE, verify_cb) 2500 client = Connection(clientContext, None) 2501 client.set_connect_state() 2502 2503 interact_in_memory(client, server) 2504 2505 chain = client.get_peer_cert_chain() 2506 assert len(chain) == 3 2507 assert "Server Certificate" == chain[0].get_subject().CN 2508 assert "Intermediate Certificate" == chain[1].get_subject().CN 2509 assert "Authority Certificate" == chain[2].get_subject().CN 2510 2511 def test_get_peer_cert_chain_none(self): 2512 """ 2513 `Connection.get_peer_cert_chain` returns `None` if the peer sends 2514 no certificate chain. 
2515 """ 2516 ctx = Context(SSLv23_METHOD) 2517 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 2518 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 2519 server = Connection(ctx, None) 2520 server.set_accept_state() 2521 client = Connection(Context(SSLv23_METHOD), None) 2522 client.set_connect_state() 2523 interact_in_memory(client, server) 2524 assert None is server.get_peer_cert_chain() 2525 2526 def test_get_verified_chain(self): 2527 """ 2528 `Connection.get_verified_chain` returns a list of certificates 2529 which the connected server returned for the certification verification. 2530 """ 2531 chain = _create_certificate_chain() 2532 [(cakey, cacert), (ikey, icert), (skey, scert)] = chain 2533 2534 serverContext = Context(SSLv23_METHOD) 2535 serverContext.use_privatekey(skey) 2536 serverContext.use_certificate(scert) 2537 serverContext.add_extra_chain_cert(icert) 2538 serverContext.add_extra_chain_cert(cacert) 2539 server = Connection(serverContext, None) 2540 server.set_accept_state() 2541 2542 # Create the client 2543 clientContext = Context(SSLv23_METHOD) 2544 # cacert is self-signed so the client must trust it for verification 2545 # to succeed. 2546 clientContext.get_cert_store().add_cert(cacert) 2547 clientContext.set_verify(VERIFY_PEER, verify_cb) 2548 client = Connection(clientContext, None) 2549 client.set_connect_state() 2550 2551 interact_in_memory(client, server) 2552 2553 chain = client.get_verified_chain() 2554 assert len(chain) == 3 2555 assert "Server Certificate" == chain[0].get_subject().CN 2556 assert "Intermediate Certificate" == chain[1].get_subject().CN 2557 assert "Authority Certificate" == chain[2].get_subject().CN 2558 2559 def test_get_verified_chain_none(self): 2560 """ 2561 `Connection.get_verified_chain` returns `None` if the peer sends 2562 no certificate chain. 
2563 """ 2564 ctx = Context(SSLv23_METHOD) 2565 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 2566 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 2567 server = Connection(ctx, None) 2568 server.set_accept_state() 2569 client = Connection(Context(SSLv23_METHOD), None) 2570 client.set_connect_state() 2571 interact_in_memory(client, server) 2572 assert None is server.get_verified_chain() 2573 2574 def test_get_verified_chain_unconnected(self): 2575 """ 2576 `Connection.get_verified_chain` returns `None` when used with an object 2577 which has not been connected. 2578 """ 2579 ctx = Context(SSLv23_METHOD) 2580 server = Connection(ctx, None) 2581 assert None is server.get_verified_chain() 2582 2583 def test_get_session_unconnected(self): 2584 """ 2585 `Connection.get_session` returns `None` when used with an object 2586 which has not been connected. 2587 """ 2588 ctx = Context(SSLv23_METHOD) 2589 server = Connection(ctx, None) 2590 session = server.get_session() 2591 assert None is session 2592 2593 def test_server_get_session(self): 2594 """ 2595 On the server side of a connection, `Connection.get_session` returns a 2596 `Session` instance representing the SSL session for that connection. 2597 """ 2598 server, client = loopback() 2599 session = server.get_session() 2600 assert isinstance(session, Session) 2601 2602 def test_client_get_session(self): 2603 """ 2604 On the client side of a connection, `Connection.get_session` 2605 returns a `Session` instance representing the SSL session for 2606 that connection. 2607 """ 2608 server, client = loopback() 2609 session = client.get_session() 2610 assert isinstance(session, Session) 2611 2612 def test_set_session_wrong_args(self): 2613 """ 2614 `Connection.set_session` raises `TypeError` if called with an object 2615 that is not an instance of `Session`. 
2616 """ 2617 ctx = Context(SSLv23_METHOD) 2618 connection = Connection(ctx, None) 2619 with pytest.raises(TypeError): 2620 connection.set_session(123) 2621 with pytest.raises(TypeError): 2622 connection.set_session("hello") 2623 with pytest.raises(TypeError): 2624 connection.set_session(object()) 2625 2626 def test_client_set_session(self): 2627 """ 2628 `Connection.set_session`, when used prior to a connection being 2629 established, accepts a `Session` instance and causes an attempt to 2630 re-use the session it represents when the SSL handshake is performed. 2631 """ 2632 key = load_privatekey(FILETYPE_PEM, server_key_pem) 2633 cert = load_certificate(FILETYPE_PEM, server_cert_pem) 2634 ctx = Context(TLSv1_2_METHOD) 2635 ctx.use_privatekey(key) 2636 ctx.use_certificate(cert) 2637 ctx.set_session_id("unity-test") 2638 2639 def makeServer(socket): 2640 server = Connection(ctx, socket) 2641 server.set_accept_state() 2642 return server 2643 2644 originalServer, originalClient = loopback(server_factory=makeServer) 2645 originalSession = originalClient.get_session() 2646 2647 def makeClient(socket): 2648 client = loopback_client_factory(socket) 2649 client.set_session(originalSession) 2650 return client 2651 2652 resumedServer, resumedClient = loopback( 2653 server_factory=makeServer, client_factory=makeClient 2654 ) 2655 2656 # This is a proxy: in general, we have no access to any unique 2657 # identifier for the session (new enough versions of OpenSSL expose 2658 # a hash which could be usable, but "new enough" is very, very new). 2659 # Instead, exploit the fact that the master key is re-used if the 2660 # session is re-used. As long as the master key for the two 2661 # connections is the same, the session was re-used! 
2662 assert originalServer.master_key() == resumedServer.master_key() 2663 2664 def test_set_session_wrong_method(self): 2665 """ 2666 If `Connection.set_session` is passed a `Session` instance associated 2667 with a context using a different SSL method than the `Connection` 2668 is using, a `OpenSSL.SSL.Error` is raised. 2669 """ 2670 v1 = TLSv1_2_METHOD 2671 v2 = TLSv1_METHOD 2672 2673 key = load_privatekey(FILETYPE_PEM, server_key_pem) 2674 cert = load_certificate(FILETYPE_PEM, server_cert_pem) 2675 ctx = Context(v1) 2676 ctx.use_privatekey(key) 2677 ctx.use_certificate(cert) 2678 ctx.set_session_id(b"unity-test") 2679 2680 def makeServer(socket): 2681 server = Connection(ctx, socket) 2682 server.set_accept_state() 2683 return server 2684 2685 def makeOriginalClient(socket): 2686 client = Connection(Context(v1), socket) 2687 client.set_connect_state() 2688 return client 2689 2690 originalServer, originalClient = loopback( 2691 server_factory=makeServer, client_factory=makeOriginalClient 2692 ) 2693 originalSession = originalClient.get_session() 2694 2695 def makeClient(socket): 2696 # Intentionally use a different, incompatible method here. 2697 client = Connection(Context(v2), socket) 2698 client.set_connect_state() 2699 client.set_session(originalSession) 2700 return client 2701 2702 with pytest.raises(Error): 2703 loopback(client_factory=makeClient, server_factory=makeServer) 2704 2705 def test_wantWriteError(self): 2706 """ 2707 `Connection` methods which generate output raise 2708 `OpenSSL.SSL.WantWriteError` if writing to the connection's BIO 2709 fail indicating a should-write state. 2710 """ 2711 client_socket, server_socket = socket_pair() 2712 # Fill up the client's send buffer so Connection won't be able to write 2713 # anything. Only write a single byte at a time so we can be sure we 2714 # completely fill the buffer. 
Even though the socket API is allowed to 2715 # signal a short write via its return value it seems this doesn't 2716 # always happen on all platforms (FreeBSD and OS X particular) for the 2717 # very last bit of available buffer space. 2718 msg = b"x" 2719 for i in range(1024 * 1024 * 64): 2720 try: 2721 client_socket.send(msg) 2722 except error as e: 2723 if e.errno == EWOULDBLOCK: 2724 break 2725 raise 2726 else: 2727 pytest.fail( 2728 "Failed to fill socket buffer, cannot test BIO want write" 2729 ) 2730 2731 ctx = Context(SSLv23_METHOD) 2732 conn = Connection(ctx, client_socket) 2733 # Client's speak first, so make it an SSL client 2734 conn.set_connect_state() 2735 with pytest.raises(WantWriteError): 2736 conn.do_handshake() 2737 2738 # XXX want_read 2739 2740 def test_get_finished_before_connect(self): 2741 """ 2742 `Connection.get_finished` returns `None` before TLS handshake 2743 is completed. 2744 """ 2745 ctx = Context(SSLv23_METHOD) 2746 connection = Connection(ctx, None) 2747 assert connection.get_finished() is None 2748 2749 def test_get_peer_finished_before_connect(self): 2750 """ 2751 `Connection.get_peer_finished` returns `None` before TLS handshake 2752 is completed. 2753 """ 2754 ctx = Context(SSLv23_METHOD) 2755 connection = Connection(ctx, None) 2756 assert connection.get_peer_finished() is None 2757 2758 def test_get_finished(self): 2759 """ 2760 `Connection.get_finished` method returns the TLS Finished message send 2761 from client, or server. Finished messages are send during 2762 TLS handshake. 2763 """ 2764 server, client = loopback() 2765 2766 assert server.get_finished() is not None 2767 assert len(server.get_finished()) > 0 2768 2769 def test_get_peer_finished(self): 2770 """ 2771 `Connection.get_peer_finished` method returns the TLS Finished 2772 message received from client, or server. Finished messages are send 2773 during TLS handshake. 
2774 """ 2775 server, client = loopback() 2776 2777 assert server.get_peer_finished() is not None 2778 assert len(server.get_peer_finished()) > 0 2779 2780 def test_tls_finished_message_symmetry(self): 2781 """ 2782 The TLS Finished message send by server must be the TLS Finished 2783 message received by client. 2784 2785 The TLS Finished message send by client must be the TLS Finished 2786 message received by server. 2787 """ 2788 server, client = loopback() 2789 2790 assert server.get_finished() == client.get_peer_finished() 2791 assert client.get_finished() == server.get_peer_finished() 2792 2793 def test_get_cipher_name_before_connect(self): 2794 """ 2795 `Connection.get_cipher_name` returns `None` if no connection 2796 has been established. 2797 """ 2798 ctx = Context(SSLv23_METHOD) 2799 conn = Connection(ctx, None) 2800 assert conn.get_cipher_name() is None 2801 2802 def test_get_cipher_name(self): 2803 """ 2804 `Connection.get_cipher_name` returns a `unicode` string giving the 2805 name of the currently used cipher. 2806 """ 2807 server, client = loopback() 2808 server_cipher_name, client_cipher_name = ( 2809 server.get_cipher_name(), 2810 client.get_cipher_name(), 2811 ) 2812 2813 assert isinstance(server_cipher_name, text_type) 2814 assert isinstance(client_cipher_name, text_type) 2815 2816 assert server_cipher_name == client_cipher_name 2817 2818 def test_get_cipher_version_before_connect(self): 2819 """ 2820 `Connection.get_cipher_version` returns `None` if no connection 2821 has been established. 2822 """ 2823 ctx = Context(SSLv23_METHOD) 2824 conn = Connection(ctx, None) 2825 assert conn.get_cipher_version() is None 2826 2827 def test_get_cipher_version(self): 2828 """ 2829 `Connection.get_cipher_version` returns a `unicode` string giving 2830 the protocol name of the currently used cipher. 
2831 """ 2832 server, client = loopback() 2833 server_cipher_version, client_cipher_version = ( 2834 server.get_cipher_version(), 2835 client.get_cipher_version(), 2836 ) 2837 2838 assert isinstance(server_cipher_version, text_type) 2839 assert isinstance(client_cipher_version, text_type) 2840 2841 assert server_cipher_version == client_cipher_version 2842 2843 def test_get_cipher_bits_before_connect(self): 2844 """ 2845 `Connection.get_cipher_bits` returns `None` if no connection has 2846 been established. 2847 """ 2848 ctx = Context(SSLv23_METHOD) 2849 conn = Connection(ctx, None) 2850 assert conn.get_cipher_bits() is None 2851 2852 def test_get_cipher_bits(self): 2853 """ 2854 `Connection.get_cipher_bits` returns the number of secret bits 2855 of the currently used cipher. 2856 """ 2857 server, client = loopback() 2858 server_cipher_bits, client_cipher_bits = ( 2859 server.get_cipher_bits(), 2860 client.get_cipher_bits(), 2861 ) 2862 2863 assert isinstance(server_cipher_bits, int) 2864 assert isinstance(client_cipher_bits, int) 2865 2866 assert server_cipher_bits == client_cipher_bits 2867 2868 def test_get_protocol_version_name(self): 2869 """ 2870 `Connection.get_protocol_version_name()` returns a string giving the 2871 protocol version of the current connection. 2872 """ 2873 server, client = loopback() 2874 client_protocol_version_name = client.get_protocol_version_name() 2875 server_protocol_version_name = server.get_protocol_version_name() 2876 2877 assert isinstance(server_protocol_version_name, text_type) 2878 assert isinstance(client_protocol_version_name, text_type) 2879 2880 assert server_protocol_version_name == client_protocol_version_name 2881 2882 def test_get_protocol_version(self): 2883 """ 2884 `Connection.get_protocol_version()` returns an integer 2885 giving the protocol version of the current connection. 
2886 """ 2887 server, client = loopback() 2888 client_protocol_version = client.get_protocol_version() 2889 server_protocol_version = server.get_protocol_version() 2890 2891 assert isinstance(server_protocol_version, int) 2892 assert isinstance(client_protocol_version, int) 2893 2894 assert server_protocol_version == client_protocol_version 2895 2896 def test_wantReadError(self): 2897 """ 2898 `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are 2899 no bytes available to be read from the BIO. 2900 """ 2901 ctx = Context(SSLv23_METHOD) 2902 conn = Connection(ctx, None) 2903 with pytest.raises(WantReadError): 2904 conn.bio_read(1024) 2905 2906 @pytest.mark.parametrize("bufsize", [1.0, None, object(), "bufsize"]) 2907 def test_bio_read_wrong_args(self, bufsize): 2908 """ 2909 `Connection.bio_read` raises `TypeError` if passed a non-integer 2910 argument. 2911 """ 2912 ctx = Context(SSLv23_METHOD) 2913 conn = Connection(ctx, None) 2914 with pytest.raises(TypeError): 2915 conn.bio_read(bufsize) 2916 2917 def test_buffer_size(self): 2918 """ 2919 `Connection.bio_read` accepts an integer giving the maximum number 2920 of bytes to read and return. 2921 """ 2922 ctx = Context(SSLv23_METHOD) 2923 conn = Connection(ctx, None) 2924 conn.set_connect_state() 2925 try: 2926 conn.do_handshake() 2927 except WantReadError: 2928 pass 2929 data = conn.bio_read(2) 2930 assert 2 == len(data) 2931 2932 2933class TestConnectionGetCipherList(object): 2934 """ 2935 Tests for `Connection.get_cipher_list`. 2936 """ 2937 2938 def test_result(self): 2939 """ 2940 `Connection.get_cipher_list` returns a list of `bytes` giving the 2941 names of the ciphers which might be used. 
2942 """ 2943 connection = Connection(Context(SSLv23_METHOD), None) 2944 ciphers = connection.get_cipher_list() 2945 assert isinstance(ciphers, list) 2946 for cipher in ciphers: 2947 assert isinstance(cipher, str) 2948 2949 2950class VeryLarge(bytes): 2951 """ 2952 Mock object so that we don't have to allocate 2**31 bytes 2953 """ 2954 2955 def __len__(self): 2956 return 2 ** 31 2957 2958 2959class TestConnectionSend(object): 2960 """ 2961 Tests for `Connection.send`. 2962 """ 2963 2964 def test_wrong_args(self): 2965 """ 2966 When called with arguments other than string argument for its first 2967 parameter, `Connection.send` raises `TypeError`. 2968 """ 2969 connection = Connection(Context(SSLv23_METHOD), None) 2970 with pytest.raises(TypeError): 2971 connection.send(object()) 2972 with pytest.raises(TypeError): 2973 connection.send([1, 2, 3]) 2974 2975 def test_short_bytes(self): 2976 """ 2977 When passed a short byte string, `Connection.send` transmits all of it 2978 and returns the number of bytes sent. 2979 """ 2980 server, client = loopback() 2981 count = server.send(b"xy") 2982 assert count == 2 2983 assert client.recv(2) == b"xy" 2984 2985 def test_text(self): 2986 """ 2987 When passed a text, `Connection.send` transmits all of it and 2988 returns the number of bytes sent. It also raises a DeprecationWarning. 2989 """ 2990 server, client = loopback() 2991 with pytest.warns(DeprecationWarning) as w: 2992 simplefilter("always") 2993 count = server.send(b"xy".decode("ascii")) 2994 assert "{0} for buf is no longer accepted, use bytes".format( 2995 WARNING_TYPE_EXPECTED 2996 ) == str(w[-1].message) 2997 assert count == 2 2998 assert client.recv(2) == b"xy" 2999 3000 def test_short_memoryview(self): 3001 """ 3002 When passed a memoryview onto a small number of bytes, 3003 `Connection.send` transmits all of them and returns the number 3004 of bytes sent. 
3005 """ 3006 server, client = loopback() 3007 count = server.send(memoryview(b"xy")) 3008 assert count == 2 3009 assert client.recv(2) == b"xy" 3010 3011 def test_short_bytearray(self): 3012 """ 3013 When passed a short bytearray, `Connection.send` transmits all of 3014 it and returns the number of bytes sent. 3015 """ 3016 server, client = loopback() 3017 count = server.send(bytearray(b"xy")) 3018 assert count == 2 3019 assert client.recv(2) == b"xy" 3020 3021 @skip_if_py3 3022 def test_short_buffer(self): 3023 """ 3024 When passed a buffer containing a small number of bytes, 3025 `Connection.send` transmits all of them and returns the number 3026 of bytes sent. 3027 """ 3028 server, client = loopback() 3029 count = server.send(buffer(b"xy")) # noqa: F821 3030 assert count == 2 3031 assert client.recv(2) == b"xy" 3032 3033 @pytest.mark.skipif( 3034 sys.maxsize < 2 ** 31, 3035 reason="sys.maxsize < 2**31 - test requires 64 bit", 3036 ) 3037 def test_buf_too_large(self): 3038 """ 3039 When passed a buffer containing >= 2**31 bytes, 3040 `Connection.send` bails out as SSL_write only 3041 accepts an int for the buffer length. 3042 """ 3043 connection = Connection(Context(SSLv23_METHOD), None) 3044 with pytest.raises(ValueError) as exc_info: 3045 connection.send(VeryLarge()) 3046 exc_info.match(r"Cannot send more than .+ bytes at once") 3047 3048 3049def _make_memoryview(size): 3050 """ 3051 Create a new ``memoryview`` wrapped around a ``bytearray`` of the given 3052 size. 3053 """ 3054 return memoryview(bytearray(size)) 3055 3056 3057class TestConnectionRecvInto(object): 3058 """ 3059 Tests for `Connection.recv_into`. 3060 """ 3061 3062 def _no_length_test(self, factory): 3063 """ 3064 Assert that when the given buffer is passed to `Connection.recv_into`, 3065 whatever bytes are available to be received that fit into that buffer 3066 are written into that buffer. 
3067 """ 3068 output_buffer = factory(5) 3069 3070 server, client = loopback() 3071 server.send(b"xy") 3072 3073 assert client.recv_into(output_buffer) == 2 3074 assert output_buffer == bytearray(b"xy\x00\x00\x00") 3075 3076 def test_bytearray_no_length(self): 3077 """ 3078 `Connection.recv_into` can be passed a `bytearray` instance and data 3079 in the receive buffer is written to it. 3080 """ 3081 self._no_length_test(bytearray) 3082 3083 def _respects_length_test(self, factory): 3084 """ 3085 Assert that when the given buffer is passed to `Connection.recv_into` 3086 along with a value for `nbytes` that is less than the size of that 3087 buffer, only `nbytes` bytes are written into the buffer. 3088 """ 3089 output_buffer = factory(10) 3090 3091 server, client = loopback() 3092 server.send(b"abcdefghij") 3093 3094 assert client.recv_into(output_buffer, 5) == 5 3095 assert output_buffer == bytearray(b"abcde\x00\x00\x00\x00\x00") 3096 3097 def test_bytearray_respects_length(self): 3098 """ 3099 When called with a `bytearray` instance, `Connection.recv_into` 3100 respects the `nbytes` parameter and doesn't copy in more than that 3101 number of bytes. 3102 """ 3103 self._respects_length_test(bytearray) 3104 3105 def _doesnt_overfill_test(self, factory): 3106 """ 3107 Assert that if there are more bytes available to be read from the 3108 receive buffer than would fit into the buffer passed to 3109 `Connection.recv_into`, only as many as fit are written into it. 3110 """ 3111 output_buffer = factory(5) 3112 3113 server, client = loopback() 3114 server.send(b"abcdefghij") 3115 3116 assert client.recv_into(output_buffer) == 5 3117 assert output_buffer == bytearray(b"abcde") 3118 rest = client.recv(5) 3119 assert b"fghij" == rest 3120 3121 def test_bytearray_doesnt_overfill(self): 3122 """ 3123 When called with a `bytearray` instance, `Connection.recv_into` 3124 respects the size of the array and doesn't write more bytes into it 3125 than will fit. 
3126 """ 3127 self._doesnt_overfill_test(bytearray) 3128 3129 def test_bytearray_really_doesnt_overfill(self): 3130 """ 3131 When called with a `bytearray` instance and an `nbytes` value that is 3132 too large, `Connection.recv_into` respects the size of the array and 3133 not the `nbytes` value and doesn't write more bytes into the buffer 3134 than will fit. 3135 """ 3136 self._doesnt_overfill_test(bytearray) 3137 3138 def test_peek(self): 3139 server, client = loopback() 3140 server.send(b"xy") 3141 3142 for _ in range(2): 3143 output_buffer = bytearray(5) 3144 assert client.recv_into(output_buffer, flags=MSG_PEEK) == 2 3145 assert output_buffer == bytearray(b"xy\x00\x00\x00") 3146 3147 def test_memoryview_no_length(self): 3148 """ 3149 `Connection.recv_into` can be passed a `memoryview` instance and data 3150 in the receive buffer is written to it. 3151 """ 3152 self._no_length_test(_make_memoryview) 3153 3154 def test_memoryview_respects_length(self): 3155 """ 3156 When called with a `memoryview` instance, `Connection.recv_into` 3157 respects the ``nbytes`` parameter and doesn't copy more than that 3158 number of bytes in. 3159 """ 3160 self._respects_length_test(_make_memoryview) 3161 3162 def test_memoryview_doesnt_overfill(self): 3163 """ 3164 When called with a `memoryview` instance, `Connection.recv_into` 3165 respects the size of the array and doesn't write more bytes into it 3166 than will fit. 3167 """ 3168 self._doesnt_overfill_test(_make_memoryview) 3169 3170 def test_memoryview_really_doesnt_overfill(self): 3171 """ 3172 When called with a `memoryview` instance and an `nbytes` value that is 3173 too large, `Connection.recv_into` respects the size of the array and 3174 not the `nbytes` value and doesn't write more bytes into the buffer 3175 than will fit. 3176 """ 3177 self._doesnt_overfill_test(_make_memoryview) 3178 3179 3180class TestConnectionSendall(object): 3181 """ 3182 Tests for `Connection.sendall`. 
3183 """ 3184 3185 def test_wrong_args(self): 3186 """ 3187 When called with arguments other than a string argument for its first 3188 parameter, `Connection.sendall` raises `TypeError`. 3189 """ 3190 connection = Connection(Context(SSLv23_METHOD), None) 3191 with pytest.raises(TypeError): 3192 connection.sendall(object()) 3193 with pytest.raises(TypeError): 3194 connection.sendall([1, 2, 3]) 3195 3196 def test_short(self): 3197 """ 3198 `Connection.sendall` transmits all of the bytes in the string 3199 passed to it. 3200 """ 3201 server, client = loopback() 3202 server.sendall(b"x") 3203 assert client.recv(1) == b"x" 3204 3205 def test_text(self): 3206 """ 3207 `Connection.sendall` transmits all the content in the string passed 3208 to it, raising a DeprecationWarning in case of this being a text. 3209 """ 3210 server, client = loopback() 3211 with pytest.warns(DeprecationWarning) as w: 3212 simplefilter("always") 3213 server.sendall(b"x".decode("ascii")) 3214 assert "{0} for buf is no longer accepted, use bytes".format( 3215 WARNING_TYPE_EXPECTED 3216 ) == str(w[-1].message) 3217 assert client.recv(1) == b"x" 3218 3219 def test_short_memoryview(self): 3220 """ 3221 When passed a memoryview onto a small number of bytes, 3222 `Connection.sendall` transmits all of them. 3223 """ 3224 server, client = loopback() 3225 server.sendall(memoryview(b"x")) 3226 assert client.recv(1) == b"x" 3227 3228 @skip_if_py3 3229 def test_short_buffers(self): 3230 """ 3231 When passed a buffer containing a small number of bytes, 3232 `Connection.sendall` transmits all of them. 3233 """ 3234 server, client = loopback() 3235 count = server.sendall(buffer(b"xy")) # noqa: F821 3236 assert count == 2 3237 assert client.recv(2) == b"xy" 3238 3239 def test_long(self): 3240 """ 3241 `Connection.sendall` transmits all the bytes in the string passed to it 3242 even if this requires multiple calls of an underlying write function. 
3243 """ 3244 server, client = loopback() 3245 # Should be enough, underlying SSL_write should only do 16k at a time. 3246 # On Windows, after 32k of bytes the write will block (forever 3247 # - because no one is yet reading). 3248 message = b"x" * (1024 * 32 - 1) + b"y" 3249 server.sendall(message) 3250 accum = [] 3251 received = 0 3252 while received < len(message): 3253 data = client.recv(1024) 3254 accum.append(data) 3255 received += len(data) 3256 assert message == b"".join(accum) 3257 3258 def test_closed(self): 3259 """ 3260 If the underlying socket is closed, `Connection.sendall` propagates the 3261 write error from the low level write call. 3262 """ 3263 server, client = loopback() 3264 server.sock_shutdown(2) 3265 with pytest.raises(SysCallError) as err: 3266 server.sendall(b"hello, world") 3267 if platform == "win32": 3268 assert err.value.args[0] == ESHUTDOWN 3269 else: 3270 assert err.value.args[0] == EPIPE 3271 3272 3273class TestConnectionRenegotiate(object): 3274 """ 3275 Tests for SSL renegotiation APIs. 3276 """ 3277 3278 def test_total_renegotiations(self): 3279 """ 3280 `Connection.total_renegotiations` returns `0` before any renegotiations 3281 have happened. 3282 """ 3283 connection = Connection(Context(SSLv23_METHOD), None) 3284 assert connection.total_renegotiations() == 0 3285 3286 def test_renegotiate(self): 3287 """ 3288 Go through a complete renegotiation cycle. 
"""
        # Both ends are pinned to TLSv1_2_METHOD for this test.
        server, client = loopback(
            lambda s: loopback_server_factory(s, TLSv1_2_METHOD),
            lambda s: loopback_client_factory(s, TLSv1_2_METHOD),
        )

        # Exchange some application data before renegotiating.
        server.send(b"hello world")

        assert b"hello world" == client.recv(len(b"hello world"))

        # Nothing has been renegotiated or requested yet.
        assert 0 == server.total_renegotiations()
        assert False is server.renegotiate_pending()

        # Requesting a renegotiation marks one as pending on the server.
        assert True is server.renegotiate()

        assert True is server.renegotiate_pending()

        server.setblocking(False)
        client.setblocking(False)

        # Drive the new handshake from both ends.
        client.do_handshake()
        server.do_handshake()

        assert 1 == server.total_renegotiations()
        # NOTE(review): this busy-waits for as long as no renegotiation is
        # pending and nothing in the loop body changes that state - confirm
        # the intended exit condition.
        while False is server.renegotiate_pending():
            pass


class TestError(object):
    """
    Unit tests for `OpenSSL.SSL.Error`.
    """

    def test_type(self):
        """
        `Error` is an exception type.
        """
        assert issubclass(Error, Exception)
        assert Error.__name__ == "Error"


class TestConstants(object):
    """
    Tests for the values of constants exposed in `OpenSSL.SSL`.

    These are values defined by OpenSSL intended only to be used as flags to
    OpenSSL APIs.  The only assertions it seems can be made about them is
    their values.
    """

    @pytest.mark.skipif(
        OP_NO_QUERY_MTU is None,
        reason="OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old",
    )
    def test_op_no_query_mtu(self):
        """
        The value of `OpenSSL.SSL.OP_NO_QUERY_MTU` is 0x1000, the value
        of `SSL_OP_NO_QUERY_MTU` defined by `openssl/ssl.h`.
        """
        assert OP_NO_QUERY_MTU == 0x1000

    @pytest.mark.skipif(
        OP_COOKIE_EXCHANGE is None,
        reason="OP_COOKIE_EXCHANGE unavailable - "
        "OpenSSL version may be too old",
    )
    def test_op_cookie_exchange(self):
        """
        The value of `OpenSSL.SSL.OP_COOKIE_EXCHANGE` is 0x2000, the
        value of `SSL_OP_COOKIE_EXCHANGE` defined by `openssl/ssl.h`.
3359 """ 3360 assert OP_COOKIE_EXCHANGE == 0x2000 3361 3362 @pytest.mark.skipif( 3363 OP_NO_TICKET is None, 3364 reason="OP_NO_TICKET unavailable - OpenSSL version may be too old", 3365 ) 3366 def test_op_no_ticket(self): 3367 """ 3368 The value of `OpenSSL.SSL.OP_NO_TICKET` is 0x4000, the value of 3369 `SSL_OP_NO_TICKET` defined by `openssl/ssl.h`. 3370 """ 3371 assert OP_NO_TICKET == 0x4000 3372 3373 @pytest.mark.skipif( 3374 OP_NO_COMPRESSION is None, 3375 reason=( 3376 "OP_NO_COMPRESSION unavailable - OpenSSL version may be too old" 3377 ), 3378 ) 3379 def test_op_no_compression(self): 3380 """ 3381 The value of `OpenSSL.SSL.OP_NO_COMPRESSION` is 0x20000, the 3382 value of `SSL_OP_NO_COMPRESSION` defined by `openssl/ssl.h`. 3383 """ 3384 assert OP_NO_COMPRESSION == 0x20000 3385 3386 def test_sess_cache_off(self): 3387 """ 3388 The value of `OpenSSL.SSL.SESS_CACHE_OFF` 0x0, the value of 3389 `SSL_SESS_CACHE_OFF` defined by `openssl/ssl.h`. 3390 """ 3391 assert 0x0 == SESS_CACHE_OFF 3392 3393 def test_sess_cache_client(self): 3394 """ 3395 The value of `OpenSSL.SSL.SESS_CACHE_CLIENT` 0x1, the value of 3396 `SSL_SESS_CACHE_CLIENT` defined by `openssl/ssl.h`. 3397 """ 3398 assert 0x1 == SESS_CACHE_CLIENT 3399 3400 def test_sess_cache_server(self): 3401 """ 3402 The value of `OpenSSL.SSL.SESS_CACHE_SERVER` 0x2, the value of 3403 `SSL_SESS_CACHE_SERVER` defined by `openssl/ssl.h`. 3404 """ 3405 assert 0x2 == SESS_CACHE_SERVER 3406 3407 def test_sess_cache_both(self): 3408 """ 3409 The value of `OpenSSL.SSL.SESS_CACHE_BOTH` 0x3, the value of 3410 `SSL_SESS_CACHE_BOTH` defined by `openssl/ssl.h`. 3411 """ 3412 assert 0x3 == SESS_CACHE_BOTH 3413 3414 def test_sess_cache_no_auto_clear(self): 3415 """ 3416 The value of `OpenSSL.SSL.SESS_CACHE_NO_AUTO_CLEAR` 0x80, the 3417 value of `SSL_SESS_CACHE_NO_AUTO_CLEAR` defined by 3418 `openssl/ssl.h`. 
"""
        assert 0x80 == SESS_CACHE_NO_AUTO_CLEAR

    def test_sess_cache_no_internal_lookup(self):
        """
        The value of `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_LOOKUP` is 0x100,
        the value of `SSL_SESS_CACHE_NO_INTERNAL_LOOKUP` defined by
        `openssl/ssl.h`.
        """
        assert 0x100 == SESS_CACHE_NO_INTERNAL_LOOKUP

    def test_sess_cache_no_internal_store(self):
        """
        The value of `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL_STORE` is 0x200,
        the value of `SSL_SESS_CACHE_NO_INTERNAL_STORE` defined by
        `openssl/ssl.h`.
        """
        assert 0x200 == SESS_CACHE_NO_INTERNAL_STORE

    def test_sess_cache_no_internal(self):
        """
        The value of `OpenSSL.SSL.SESS_CACHE_NO_INTERNAL` is 0x300, the
        value of `SSL_SESS_CACHE_NO_INTERNAL` defined by
        `openssl/ssl.h`.
        """
        assert 0x300 == SESS_CACHE_NO_INTERNAL


class TestMemoryBIO(object):
    """
    Tests for `OpenSSL.SSL.Connection` using a memory BIO.
    """

    def _server(self, sock):
        """
        Create a new server-side SSL `Connection` object wrapped around `sock`.

        :param sock: The socket to wrap, or None to use a memory BIO.
        :return: A `Connection` in accept state.
        """
        # Create the server side Connection.  This is mostly setup boilerplate
        # - use SSLv23_METHOD with SSLv2 and SSLv3 disabled, require a peer
        # certificate, use a particular certificate, etc.
        server_ctx = Context(SSLv23_METHOD)
        server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE)
        server_ctx.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE,
            verify_cb,
        )
        server_store = server_ctx.get_cert_store()
        server_ctx.use_privatekey(
            load_privatekey(FILETYPE_PEM, server_key_pem)
        )
        server_ctx.use_certificate(
            load_certificate(FILETYPE_PEM, server_cert_pem)
        )
        # Confirm that the key and the certificate loaded above belong together.
        server_ctx.check_privatekey()
        # Add the root certificate to the store used to verify the peer.
        server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
        # Here the Connection is actually created.  If None is passed as the
        # 2nd parameter, it indicates a memory BIO should be created.
        server_conn = Connection(server_ctx, sock)
        server_conn.set_accept_state()
        return server_conn

    def _client(self, sock):
        """
        Create a new client-side SSL `Connection` object wrapped around `sock`.

        :param sock: The socket to wrap, or None to use a memory BIO.
        :return: A `Connection` in connect state.
        """
        # Now create the client side Connection.  Similar boilerplate to the
        # above.
        client_ctx = Context(SSLv23_METHOD)
        client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE)
        client_ctx.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT | VERIFY_CLIENT_ONCE,
            verify_cb,
        )
        client_store = client_ctx.get_cert_store()
        client_ctx.use_privatekey(
            load_privatekey(FILETYPE_PEM, client_key_pem)
        )
        client_ctx.use_certificate(
            load_certificate(FILETYPE_PEM, client_cert_pem)
        )
        client_ctx.check_privatekey()
        client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
        client_conn = Connection(client_ctx, sock)
        client_conn.set_connect_state()
        return client_conn

    def test_memory_connect(self):
        """
        Two `Connection`s which use memory BIOs can be manually connected by
        reading from the output of each and writing those bytes to the input of
        the other and in this way establish a connection and exchange
        application-level bytes with each other.
        """
        server_conn = self._server(None)
        client_conn = self._client(None)

        # There should be no key or nonces yet.
        assert server_conn.master_key() is None
        assert server_conn.client_random() is None
        assert server_conn.server_random() is None

        # First, the handshake needs to happen.  We'll deliver bytes back and
        # forth between the client and server until neither of them feels like
        # speaking any more.
        assert interact_in_memory(client_conn, server_conn) is None

        # Now that the handshake is done, there should be a key and nonces.
        assert server_conn.master_key() is not None
        assert server_conn.client_random() is not None
        assert server_conn.server_random() is not None
        # Both ends must agree on each nonce, and the two nonces must differ
        # from one another.
        assert server_conn.client_random() == client_conn.client_random()
        assert server_conn.server_random() == client_conn.server_random()
        assert server_conn.client_random() != server_conn.server_random()
        assert client_conn.client_random() != client_conn.server_random()

        # Export key material for other uses.
        cekm = client_conn.export_keying_material(b"LABEL", 32)
        sekm = server_conn.export_keying_material(b"LABEL", 32)
        assert cekm is not None
        assert sekm is not None
        assert cekm == sekm
        assert len(sekm) == 32

        # Export key material for other uses with additional context.
        cekmc = client_conn.export_keying_material(b"LABEL", 32, b"CONTEXT")
        sekmc = server_conn.export_keying_material(b"LABEL", 32, b"CONTEXT")
        assert cekmc is not None
        assert sekmc is not None
        assert cekmc == sekmc
        # Supplying a context must change the exported material.
        assert cekmc != cekm
        assert sekmc != sekm
        # Export with alternate label
        cekmt = client_conn.export_keying_material(b"test", 32, b"CONTEXT")
        sekmt = server_conn.export_keying_material(b"test", 32, b"CONTEXT")
        assert cekmc != cekmt
        assert sekmc != sekmt

        # Here are the bytes we'll try to send.
        important_message = b"One if by land, two if by sea."

        # Server to client: the result names the receiver and what it read.
        server_conn.write(important_message)
        assert interact_in_memory(client_conn, server_conn) == (
            client_conn,
            important_message,
        )

        # And the reversed message from client to server.
        client_conn.write(important_message[::-1])
        assert interact_in_memory(client_conn, server_conn) == (
            server_conn,
            important_message[::-1],
        )

    def test_socket_connect(self):
        """
        Just like `test_memory_connect` but with an actual socket.
3573 3574 This is primarily to rule out the memory BIO code as the source of any 3575 problems encountered while passing data over a `Connection` (if 3576 this test fails, there must be a problem outside the memory BIO code, 3577 as no memory BIO is involved here). Even though this isn't a memory 3578 BIO test, it's convenient to have it here. 3579 """ 3580 server_conn, client_conn = loopback() 3581 3582 important_message = b"Help me Obi Wan Kenobi, you're my only hope." 3583 client_conn.send(important_message) 3584 msg = server_conn.recv(1024) 3585 assert msg == important_message 3586 3587 # Again in the other direction, just for fun. 3588 important_message = important_message[::-1] 3589 server_conn.send(important_message) 3590 msg = client_conn.recv(1024) 3591 assert msg == important_message 3592 3593 def test_socket_overrides_memory(self): 3594 """ 3595 Test that `OpenSSL.SSL.bio_read` and `OpenSSL.SSL.bio_write` don't 3596 work on `OpenSSL.SSL.Connection`() that use sockets. 3597 """ 3598 context = Context(SSLv23_METHOD) 3599 client = socket_any_family() 3600 clientSSL = Connection(context, client) 3601 with pytest.raises(TypeError): 3602 clientSSL.bio_read(100) 3603 with pytest.raises(TypeError): 3604 clientSSL.bio_write(b"foo") 3605 with pytest.raises(TypeError): 3606 clientSSL.bio_shutdown() 3607 3608 def test_outgoing_overflow(self): 3609 """ 3610 If more bytes than can be written to the memory BIO are passed to 3611 `Connection.send` at once, the number of bytes which were written is 3612 returned and that many bytes from the beginning of the input can be 3613 read from the other end of the connection. 3614 """ 3615 server = self._server(None) 3616 client = self._client(None) 3617 3618 interact_in_memory(client, server) 3619 3620 size = 2 ** 15 3621 sent = client.send(b"x" * size) 3622 # Sanity check. We're trying to test what happens when the entire 3623 # input can't be sent. If the entire input was sent, this test is 3624 # meaningless. 
3625 assert sent < size 3626 3627 receiver, received = interact_in_memory(client, server) 3628 assert receiver is server 3629 3630 # We can rely on all of these bytes being received at once because 3631 # loopback passes 2 ** 16 to recv - more than 2 ** 15. 3632 assert len(received) == sent 3633 3634 def test_shutdown(self): 3635 """ 3636 `Connection.bio_shutdown` signals the end of the data stream 3637 from which the `Connection` reads. 3638 """ 3639 server = self._server(None) 3640 server.bio_shutdown() 3641 with pytest.raises(Error) as err: 3642 server.recv(1024) 3643 # We don't want WantReadError or ZeroReturnError or anything - it's a 3644 # handshake failure. 3645 assert type(err.value) in [Error, SysCallError] 3646 3647 def test_unexpected_EOF(self): 3648 """ 3649 If the connection is lost before an orderly SSL shutdown occurs, 3650 `OpenSSL.SSL.SysCallError` is raised with a message of 3651 "Unexpected EOF". 3652 """ 3653 server_conn, client_conn = loopback() 3654 client_conn.sock_shutdown(SHUT_RDWR) 3655 with pytest.raises(SysCallError) as err: 3656 server_conn.recv(1024) 3657 assert err.value.args == (-1, "Unexpected EOF") 3658 3659 def _check_client_ca_list(self, func): 3660 """ 3661 Verify the return value of the `get_client_ca_list` method for 3662 server and client connections. 3663 3664 :param func: A function which will be called with the server context 3665 before the client and server are connected to each other. This 3666 function should specify a list of CAs for the server to send to the 3667 client and return that same list. The list will be used to verify 3668 that `get_client_ca_list` returns the proper value at 3669 various times. 
3670 """ 3671 server = self._server(None) 3672 client = self._client(None) 3673 assert client.get_client_ca_list() == [] 3674 assert server.get_client_ca_list() == [] 3675 ctx = server.get_context() 3676 expected = func(ctx) 3677 assert client.get_client_ca_list() == [] 3678 assert server.get_client_ca_list() == expected 3679 interact_in_memory(client, server) 3680 assert client.get_client_ca_list() == expected 3681 assert server.get_client_ca_list() == expected 3682 3683 def test_set_client_ca_list_errors(self): 3684 """ 3685 `Context.set_client_ca_list` raises a `TypeError` if called with a 3686 non-list or a list that contains objects other than X509Names. 3687 """ 3688 ctx = Context(SSLv23_METHOD) 3689 with pytest.raises(TypeError): 3690 ctx.set_client_ca_list("spam") 3691 with pytest.raises(TypeError): 3692 ctx.set_client_ca_list(["spam"]) 3693 3694 def test_set_empty_ca_list(self): 3695 """ 3696 If passed an empty list, `Context.set_client_ca_list` configures the 3697 context to send no CA names to the client and, on both the server and 3698 client sides, `Connection.get_client_ca_list` returns an empty list 3699 after the connection is set up. 3700 """ 3701 3702 def no_ca(ctx): 3703 ctx.set_client_ca_list([]) 3704 return [] 3705 3706 self._check_client_ca_list(no_ca) 3707 3708 def test_set_one_ca_list(self): 3709 """ 3710 If passed a list containing a single X509Name, 3711 `Context.set_client_ca_list` configures the context to send 3712 that CA name to the client and, on both the server and client sides, 3713 `Connection.get_client_ca_list` returns a list containing that 3714 X509Name after the connection is set up. 
3715 """ 3716 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3717 cadesc = cacert.get_subject() 3718 3719 def single_ca(ctx): 3720 ctx.set_client_ca_list([cadesc]) 3721 return [cadesc] 3722 3723 self._check_client_ca_list(single_ca) 3724 3725 def test_set_multiple_ca_list(self): 3726 """ 3727 If passed a list containing multiple X509Name objects, 3728 `Context.set_client_ca_list` configures the context to send 3729 those CA names to the client and, on both the server and client sides, 3730 `Connection.get_client_ca_list` returns a list containing those 3731 X509Names after the connection is set up. 3732 """ 3733 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3734 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3735 3736 sedesc = secert.get_subject() 3737 cldesc = clcert.get_subject() 3738 3739 def multiple_ca(ctx): 3740 L = [sedesc, cldesc] 3741 ctx.set_client_ca_list(L) 3742 return L 3743 3744 self._check_client_ca_list(multiple_ca) 3745 3746 def test_reset_ca_list(self): 3747 """ 3748 If called multiple times, only the X509Names passed to the final call 3749 of `Context.set_client_ca_list` are used to configure the CA 3750 names sent to the client. 3751 """ 3752 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3753 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3754 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3755 3756 cadesc = cacert.get_subject() 3757 sedesc = secert.get_subject() 3758 cldesc = clcert.get_subject() 3759 3760 def changed_ca(ctx): 3761 ctx.set_client_ca_list([sedesc, cldesc]) 3762 ctx.set_client_ca_list([cadesc]) 3763 return [cadesc] 3764 3765 self._check_client_ca_list(changed_ca) 3766 3767 def test_mutated_ca_list(self): 3768 """ 3769 If the list passed to `Context.set_client_ca_list` is mutated 3770 afterwards, this does not affect the list of CA names sent to the 3771 client. 
3772 """ 3773 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3774 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3775 3776 cadesc = cacert.get_subject() 3777 sedesc = secert.get_subject() 3778 3779 def mutated_ca(ctx): 3780 L = [cadesc] 3781 ctx.set_client_ca_list([cadesc]) 3782 L.append(sedesc) 3783 return [cadesc] 3784 3785 self._check_client_ca_list(mutated_ca) 3786 3787 def test_add_client_ca_wrong_args(self): 3788 """ 3789 `Context.add_client_ca` raises `TypeError` if called with 3790 a non-X509 object. 3791 """ 3792 ctx = Context(SSLv23_METHOD) 3793 with pytest.raises(TypeError): 3794 ctx.add_client_ca("spam") 3795 3796 def test_one_add_client_ca(self): 3797 """ 3798 A certificate's subject can be added as a CA to be sent to the client 3799 with `Context.add_client_ca`. 3800 """ 3801 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3802 cadesc = cacert.get_subject() 3803 3804 def single_ca(ctx): 3805 ctx.add_client_ca(cacert) 3806 return [cadesc] 3807 3808 self._check_client_ca_list(single_ca) 3809 3810 def test_multiple_add_client_ca(self): 3811 """ 3812 Multiple CA names can be sent to the client by calling 3813 `Context.add_client_ca` with multiple X509 objects. 3814 """ 3815 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3816 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3817 3818 cadesc = cacert.get_subject() 3819 sedesc = secert.get_subject() 3820 3821 def multiple_ca(ctx): 3822 ctx.add_client_ca(cacert) 3823 ctx.add_client_ca(secert) 3824 return [cadesc, sedesc] 3825 3826 self._check_client_ca_list(multiple_ca) 3827 3828 def test_set_and_add_client_ca(self): 3829 """ 3830 A call to `Context.set_client_ca_list` followed by a call to 3831 `Context.add_client_ca` results in using the CA names from the 3832 first call and the CA name from the second call. 
3833 """ 3834 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3835 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3836 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3837 3838 cadesc = cacert.get_subject() 3839 sedesc = secert.get_subject() 3840 cldesc = clcert.get_subject() 3841 3842 def mixed_set_add_ca(ctx): 3843 ctx.set_client_ca_list([cadesc, sedesc]) 3844 ctx.add_client_ca(clcert) 3845 return [cadesc, sedesc, cldesc] 3846 3847 self._check_client_ca_list(mixed_set_add_ca) 3848 3849 def test_set_after_add_client_ca(self): 3850 """ 3851 A call to `Context.set_client_ca_list` after a call to 3852 `Context.add_client_ca` replaces the CA name specified by the 3853 former call with the names specified by the latter call. 3854 """ 3855 cacert = load_certificate(FILETYPE_PEM, root_cert_pem) 3856 secert = load_certificate(FILETYPE_PEM, server_cert_pem) 3857 clcert = load_certificate(FILETYPE_PEM, server_cert_pem) 3858 3859 cadesc = cacert.get_subject() 3860 sedesc = secert.get_subject() 3861 3862 def set_replaces_add_ca(ctx): 3863 ctx.add_client_ca(clcert) 3864 ctx.set_client_ca_list([cadesc]) 3865 ctx.add_client_ca(secert) 3866 return [cadesc, sedesc] 3867 3868 self._check_client_ca_list(set_replaces_add_ca) 3869 3870 3871class TestInfoConstants(object): 3872 """ 3873 Tests for assorted constants exposed for use in info callbacks. 3874 """ 3875 3876 def test_integers(self): 3877 """ 3878 All of the info constants are integers. 3879 3880 This is a very weak test. It would be nice to have one that actually 3881 verifies that as certain info events happen, the value passed to the 3882 info callback matches up with the constant exposed by OpenSSL.SSL. 
3883 """ 3884 for const in [ 3885 SSL_ST_CONNECT, 3886 SSL_ST_ACCEPT, 3887 SSL_ST_MASK, 3888 SSL_CB_LOOP, 3889 SSL_CB_EXIT, 3890 SSL_CB_READ, 3891 SSL_CB_WRITE, 3892 SSL_CB_ALERT, 3893 SSL_CB_READ_ALERT, 3894 SSL_CB_WRITE_ALERT, 3895 SSL_CB_ACCEPT_LOOP, 3896 SSL_CB_ACCEPT_EXIT, 3897 SSL_CB_CONNECT_LOOP, 3898 SSL_CB_CONNECT_EXIT, 3899 SSL_CB_HANDSHAKE_START, 3900 SSL_CB_HANDSHAKE_DONE, 3901 ]: 3902 assert isinstance(const, int) 3903 3904 # These constants don't exist on OpenSSL 1.1.0 3905 for const in [ 3906 SSL_ST_INIT, 3907 SSL_ST_BEFORE, 3908 SSL_ST_OK, 3909 SSL_ST_RENEGOTIATE, 3910 ]: 3911 assert const is None or isinstance(const, int) 3912 3913 3914class TestRequires(object): 3915 """ 3916 Tests for the decorator factory used to conditionally raise 3917 NotImplementedError when older OpenSSLs are used. 3918 """ 3919 3920 def test_available(self): 3921 """ 3922 When the OpenSSL functionality is available the decorated functions 3923 work appropriately. 3924 """ 3925 feature_guard = _make_requires(True, "Error text") 3926 results = [] 3927 3928 @feature_guard 3929 def inner(): 3930 results.append(True) 3931 return True 3932 3933 assert inner() is True 3934 assert [True] == results 3935 3936 def test_unavailable(self): 3937 """ 3938 When the OpenSSL functionality is not available the decorated function 3939 does not execute and NotImplementedError is raised. 3940 """ 3941 feature_guard = _make_requires(False, "Error text") 3942 3943 @feature_guard 3944 def inner(): # pragma: nocover 3945 pytest.fail("Should not be called") 3946 3947 with pytest.raises(NotImplementedError) as e: 3948 inner() 3949 3950 assert "Error text" in str(e.value) 3951 3952 3953class TestOCSP(object): 3954 """ 3955 Tests for PyOpenSSL's OCSP stapling support. 3956 """ 3957 3958 sample_ocsp_data = b"this is totally ocsp data" 3959 3960 def _client_connection(self, callback, data, request_ocsp=True): 3961 """ 3962 Builds a client connection suitable for using OCSP. 
3963 3964 :param callback: The callback to register for OCSP. 3965 :param data: The opaque data object that will be handed to the 3966 OCSP callback. 3967 :param request_ocsp: Whether the client will actually ask for OCSP 3968 stapling. Useful for testing only. 3969 """ 3970 ctx = Context(SSLv23_METHOD) 3971 ctx.set_ocsp_client_callback(callback, data) 3972 client = Connection(ctx) 3973 3974 if request_ocsp: 3975 client.request_ocsp() 3976 3977 client.set_connect_state() 3978 return client 3979 3980 def _server_connection(self, callback, data): 3981 """ 3982 Builds a server connection suitable for using OCSP. 3983 3984 :param callback: The callback to register for OCSP. 3985 :param data: The opaque data object that will be handed to the 3986 OCSP callback. 3987 """ 3988 ctx = Context(SSLv23_METHOD) 3989 ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) 3990 ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) 3991 ctx.set_ocsp_server_callback(callback, data) 3992 server = Connection(ctx) 3993 server.set_accept_state() 3994 return server 3995 3996 def test_callbacks_arent_called_by_default(self): 3997 """ 3998 If both the client and the server have registered OCSP callbacks, but 3999 the client does not send the OCSP request, neither callback gets 4000 called. 4001 """ 4002 4003 def ocsp_callback(*args, **kwargs): # pragma: nocover 4004 pytest.fail("Should not be called") 4005 4006 client = self._client_connection( 4007 callback=ocsp_callback, data=None, request_ocsp=False 4008 ) 4009 server = self._server_connection(callback=ocsp_callback, data=None) 4010 handshake_in_memory(client, server) 4011 4012 def test_client_negotiates_without_server(self): 4013 """ 4014 If the client wants to do OCSP but the server does not, the handshake 4015 succeeds, and the client callback fires with an empty byte string. 
4016 """ 4017 called = [] 4018 4019 def ocsp_callback(conn, ocsp_data, ignored): 4020 called.append(ocsp_data) 4021 return True 4022 4023 client = self._client_connection(callback=ocsp_callback, data=None) 4024 server = loopback_server_factory(socket=None) 4025 handshake_in_memory(client, server) 4026 4027 assert len(called) == 1 4028 assert called[0] == b"" 4029 4030 def test_client_receives_servers_data(self): 4031 """ 4032 The data the server sends in its callback is received by the client. 4033 """ 4034 calls = [] 4035 4036 def server_callback(*args, **kwargs): 4037 return self.sample_ocsp_data 4038 4039 def client_callback(conn, ocsp_data, ignored): 4040 calls.append(ocsp_data) 4041 return True 4042 4043 client = self._client_connection(callback=client_callback, data=None) 4044 server = self._server_connection(callback=server_callback, data=None) 4045 handshake_in_memory(client, server) 4046 4047 assert len(calls) == 1 4048 assert calls[0] == self.sample_ocsp_data 4049 4050 def test_callbacks_are_invoked_with_connections(self): 4051 """ 4052 The first arguments to both callbacks are their respective connections. 4053 """ 4054 client_calls = [] 4055 server_calls = [] 4056 4057 def client_callback(conn, *args, **kwargs): 4058 client_calls.append(conn) 4059 return True 4060 4061 def server_callback(conn, *args, **kwargs): 4062 server_calls.append(conn) 4063 return self.sample_ocsp_data 4064 4065 client = self._client_connection(callback=client_callback, data=None) 4066 server = self._server_connection(callback=server_callback, data=None) 4067 handshake_in_memory(client, server) 4068 4069 assert len(client_calls) == 1 4070 assert len(server_calls) == 1 4071 assert client_calls[0] is client 4072 assert server_calls[0] is server 4073 4074 def test_opaque_data_is_passed_through(self): 4075 """ 4076 Both callbacks receive an opaque, user-provided piece of data in their 4077 callbacks as the final argument. 
4078 """ 4079 calls = [] 4080 4081 def server_callback(*args): 4082 calls.append(args) 4083 return self.sample_ocsp_data 4084 4085 def client_callback(*args): 4086 calls.append(args) 4087 return True 4088 4089 sentinel = object() 4090 4091 client = self._client_connection( 4092 callback=client_callback, data=sentinel 4093 ) 4094 server = self._server_connection( 4095 callback=server_callback, data=sentinel 4096 ) 4097 handshake_in_memory(client, server) 4098 4099 assert len(calls) == 2 4100 assert calls[0][-1] is sentinel 4101 assert calls[1][-1] is sentinel 4102 4103 def test_server_returns_empty_string(self): 4104 """ 4105 If the server returns an empty bytestring from its callback, the 4106 client callback is called with the empty bytestring. 4107 """ 4108 client_calls = [] 4109 4110 def server_callback(*args): 4111 return b"" 4112 4113 def client_callback(conn, ocsp_data, ignored): 4114 client_calls.append(ocsp_data) 4115 return True 4116 4117 client = self._client_connection(callback=client_callback, data=None) 4118 server = self._server_connection(callback=server_callback, data=None) 4119 handshake_in_memory(client, server) 4120 4121 assert len(client_calls) == 1 4122 assert client_calls[0] == b"" 4123 4124 def test_client_returns_false_terminates_handshake(self): 4125 """ 4126 If the client returns False from its callback, the handshake fails. 4127 """ 4128 4129 def server_callback(*args): 4130 return self.sample_ocsp_data 4131 4132 def client_callback(*args): 4133 return False 4134 4135 client = self._client_connection(callback=client_callback, data=None) 4136 server = self._server_connection(callback=server_callback, data=None) 4137 4138 with pytest.raises(Error): 4139 handshake_in_memory(client, server) 4140 4141 def test_exceptions_in_client_bubble_up(self): 4142 """ 4143 The callbacks thrown in the client callback bubble up to the caller. 
4144 """ 4145 4146 class SentinelException(Exception): 4147 pass 4148 4149 def server_callback(*args): 4150 return self.sample_ocsp_data 4151 4152 def client_callback(*args): 4153 raise SentinelException() 4154 4155 client = self._client_connection(callback=client_callback, data=None) 4156 server = self._server_connection(callback=server_callback, data=None) 4157 4158 with pytest.raises(SentinelException): 4159 handshake_in_memory(client, server) 4160 4161 def test_exceptions_in_server_bubble_up(self): 4162 """ 4163 The callbacks thrown in the server callback bubble up to the caller. 4164 """ 4165 4166 class SentinelException(Exception): 4167 pass 4168 4169 def server_callback(*args): 4170 raise SentinelException() 4171 4172 def client_callback(*args): # pragma: nocover 4173 pytest.fail("Should not be called") 4174 4175 client = self._client_connection(callback=client_callback, data=None) 4176 server = self._server_connection(callback=server_callback, data=None) 4177 4178 with pytest.raises(SentinelException): 4179 handshake_in_memory(client, server) 4180 4181 def test_server_must_return_bytes(self): 4182 """ 4183 The server callback must return a bytestring, or a TypeError is thrown. 4184 """ 4185 4186 def server_callback(*args): 4187 return self.sample_ocsp_data.decode("ascii") 4188 4189 def client_callback(*args): # pragma: nocover 4190 pytest.fail("Should not be called") 4191 4192 client = self._client_connection(callback=client_callback, data=None) 4193 server = self._server_connection(callback=server_callback, data=None) 4194 4195 with pytest.raises(TypeError): 4196 handshake_in_memory(client, server) 4197