1# This file is dual licensed under the terms of the Apache License, Version 2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 3# for complete details. 4 5from __future__ import absolute_import, division, print_function 6 7import binascii 8import itertools 9import os 10 11import pytest 12 13from cryptography.exceptions import ( 14 AlreadyFinalized, AlreadyUpdated, InvalidSignature, InvalidTag, 15 NotYetFinalized 16) 17from cryptography.hazmat.primitives import hashes, hmac 18from cryptography.hazmat.primitives.asymmetric import rsa 19from cryptography.hazmat.primitives.ciphers import Cipher 20from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand 21from cryptography.hazmat.primitives.kdf.kbkdf import ( 22 CounterLocation, KBKDFHMAC, Mode 23) 24from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC 25 26from ...utils import load_vectors_from_file 27 28 29def _load_all_params(path, file_names, param_loader): 30 all_params = [] 31 for file_name in file_names: 32 all_params.extend( 33 load_vectors_from_file(os.path.join(path, file_name), param_loader) 34 ) 35 return all_params 36 37 38def generate_encrypt_test(param_loader, path, file_names, cipher_factory, 39 mode_factory): 40 all_params = _load_all_params(path, file_names, param_loader) 41 42 @pytest.mark.parametrize("params", all_params) 43 def test_encryption(self, backend, params): 44 encrypt_test(backend, cipher_factory, mode_factory, params) 45 46 return test_encryption 47 48 49def encrypt_test(backend, cipher_factory, mode_factory, params): 50 assert backend.cipher_supported( 51 cipher_factory(**params), mode_factory(**params) 52 ) 53 54 plaintext = params["plaintext"] 55 ciphertext = params["ciphertext"] 56 cipher = Cipher( 57 cipher_factory(**params), 58 mode_factory(**params), 59 backend=backend 60 ) 61 encryptor = cipher.encryptor() 62 actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) 63 actual_ciphertext += encryptor.finalize() 64 assert actual_ciphertext == binascii.unhexlify(ciphertext) 65 decryptor = cipher.decryptor() 66 actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) 67 actual_plaintext += decryptor.finalize() 68 assert actual_plaintext == binascii.unhexlify(plaintext) 69 70 71def generate_aead_test(param_loader, path, file_names, cipher_factory, 72 mode_factory): 73 all_params = _load_all_params(path, file_names, param_loader) 74 75 @pytest.mark.parametrize("params", all_params) 76 def test_aead(self, backend, params): 77 aead_test(backend, cipher_factory, mode_factory, params) 78 79 return test_aead 80 81 82def aead_test(backend, cipher_factory, mode_factory, params): 83 if params.get("pt") is not None: 84 plaintext = params["pt"] 85 ciphertext = params["ct"] 86 aad = params["aad"] 87 if params.get("fail") is True: 88 cipher = Cipher( 89 cipher_factory(binascii.unhexlify(params["key"])), 90 mode_factory(binascii.unhexlify(params["iv"]), 91 binascii.unhexlify(params["tag"]), 92 len(binascii.unhexlify(params["tag"]))), 93 backend 94 ) 95 decryptor = cipher.decryptor() 96 decryptor.authenticate_additional_data(binascii.unhexlify(aad)) 97 actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) 98 with pytest.raises(InvalidTag): 99 decryptor.finalize() 100 else: 101 cipher = Cipher( 102 cipher_factory(binascii.unhexlify(params["key"])), 103 mode_factory(binascii.unhexlify(params["iv"]), None), 104 backend 105 ) 106 encryptor = cipher.encryptor() 107 encryptor.authenticate_additional_data(binascii.unhexlify(aad)) 108 actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) 109 actual_ciphertext += encryptor.finalize() 110 tag_len = len(binascii.unhexlify(params["tag"])) 111 assert binascii.hexlify(encryptor.tag[:tag_len]) == params["tag"] 112 cipher = Cipher( 113 cipher_factory(binascii.unhexlify(params["key"])), 114 mode_factory(binascii.unhexlify(params["iv"]), 115 binascii.unhexlify(params["tag"]), 116 min_tag_length=tag_len), 117 backend 118 ) 119 decryptor = cipher.decryptor() 120 decryptor.authenticate_additional_data(binascii.unhexlify(aad)) 121 actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) 122 actual_plaintext += decryptor.finalize() 123 assert actual_plaintext == binascii.unhexlify(plaintext) 124 125 126def generate_stream_encryption_test(param_loader, path, file_names, 127 cipher_factory): 128 all_params = _load_all_params(path, file_names, param_loader) 129 130 @pytest.mark.parametrize("params", all_params) 131 def test_stream_encryption(self, backend, params): 132 stream_encryption_test(backend, cipher_factory, params) 133 return test_stream_encryption 134 135 136def stream_encryption_test(backend, cipher_factory, params): 137 plaintext = params["plaintext"] 138 ciphertext = params["ciphertext"] 139 offset = params["offset"] 140 cipher = Cipher(cipher_factory(**params), None, backend=backend) 141 encryptor = cipher.encryptor() 142 # throw away offset bytes 143 encryptor.update(b"\x00" * int(offset)) 144 actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) 145 actual_ciphertext += encryptor.finalize() 146 assert actual_ciphertext == binascii.unhexlify(ciphertext) 147 decryptor = cipher.decryptor() 148 decryptor.update(b"\x00" * int(offset)) 149 actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) 150 actual_plaintext += decryptor.finalize() 151 assert actual_plaintext == binascii.unhexlify(plaintext) 152 153 154def generate_hash_test(param_loader, path, file_names, hash_cls): 155 all_params = _load_all_params(path, file_names, param_loader) 156 157 @pytest.mark.parametrize("params", all_params) 158 def test_hash(self, backend, params): 159 hash_test(backend, hash_cls, params) 160 return test_hash 161 162 163def hash_test(backend, algorithm, params): 164 msg, md = params 165 m = hashes.Hash(algorithm, backend=backend) 166 m.update(binascii.unhexlify(msg)) 167 expected_md = md.replace(" ", "").lower().encode("ascii") 168 assert m.finalize() == binascii.unhexlify(expected_md) 169 170 171def generate_base_hash_test(algorithm, digest_size): 172 def test_base_hash(self, backend): 173 base_hash_test(backend, algorithm, digest_size) 174 return test_base_hash 175 176 177def base_hash_test(backend, algorithm, digest_size): 178 m = hashes.Hash(algorithm, backend=backend) 179 assert m.algorithm.digest_size == digest_size 180 m_copy = m.copy() 181 assert m != m_copy 182 assert m._ctx != m_copy._ctx 183 184 m.update(b"abc") 185 copy = m.copy() 186 copy.update(b"123") 187 m.update(b"123") 188 assert copy.finalize() == m.finalize() 189 190 191def generate_base_hmac_test(hash_cls): 192 def test_base_hmac(self, backend): 193 base_hmac_test(backend, hash_cls) 194 return test_base_hmac 195 196 197def base_hmac_test(backend, algorithm): 198 key = b"ab" 199 h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) 200 h_copy = h.copy() 201 assert h != h_copy 202 assert h._ctx != h_copy._ctx 203 204 205def generate_hmac_test(param_loader, path, file_names, algorithm): 206 all_params = _load_all_params(path, file_names, param_loader) 207 208 @pytest.mark.parametrize("params", all_params) 209 def test_hmac(self, backend, params): 210 hmac_test(backend, algorithm, params) 211 return test_hmac 212 213 214def hmac_test(backend, algorithm, params): 215 msg, md, key = params 216 h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) 217 h.update(binascii.unhexlify(msg)) 218 assert h.finalize() == binascii.unhexlify(md.encode("ascii")) 219 220 221def generate_pbkdf2_test(param_loader, path, file_names, algorithm): 222 all_params = _load_all_params(path, file_names, param_loader) 223 224 @pytest.mark.parametrize("params", all_params) 225 def test_pbkdf2(self, backend, params): 226 pbkdf2_test(backend, algorithm, params) 227 return test_pbkdf2 228 229 230def pbkdf2_test(backend, algorithm, params): 231 # Password and salt can contain \0, which should be loaded as a null char. 232 # The NIST loader loads them as literal strings so we replace with the 233 # proper value. 234 kdf = PBKDF2HMAC( 235 algorithm, 236 int(params["length"]), 237 params["salt"], 238 int(params["iterations"]), 239 backend 240 ) 241 derived_key = kdf.derive(params["password"]) 242 assert binascii.hexlify(derived_key) == params["derived_key"] 243 244 245def generate_aead_exception_test(cipher_factory, mode_factory): 246 def test_aead_exception(self, backend): 247 aead_exception_test(backend, cipher_factory, mode_factory) 248 return test_aead_exception 249 250 251def aead_exception_test(backend, cipher_factory, mode_factory): 252 cipher = Cipher( 253 cipher_factory(binascii.unhexlify(b"0" * 32)), 254 mode_factory(binascii.unhexlify(b"0" * 24)), 255 backend 256 ) 257 encryptor = cipher.encryptor() 258 encryptor.update(b"a" * 16) 259 with pytest.raises(NotYetFinalized): 260 encryptor.tag 261 with pytest.raises(AlreadyUpdated): 262 encryptor.authenticate_additional_data(b"b" * 16) 263 encryptor.finalize() 264 with pytest.raises(AlreadyFinalized): 265 encryptor.authenticate_additional_data(b"b" * 16) 266 with pytest.raises(AlreadyFinalized): 267 encryptor.update(b"b" * 16) 268 with pytest.raises(AlreadyFinalized): 269 encryptor.finalize() 270 cipher = Cipher( 271 cipher_factory(binascii.unhexlify(b"0" * 32)), 272 mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), 273 backend 274 ) 275 decryptor = cipher.decryptor() 276 decryptor.update(b"a" * 16) 277 with pytest.raises(AttributeError): 278 decryptor.tag 279 280 281def generate_aead_tag_exception_test(cipher_factory, mode_factory): 282 def test_aead_tag_exception(self, backend): 283 aead_tag_exception_test(backend, cipher_factory, mode_factory) 284 return test_aead_tag_exception 285 286 287def aead_tag_exception_test(backend, cipher_factory, mode_factory): 288 cipher = Cipher( 289 cipher_factory(binascii.unhexlify(b"0" * 32)), 290 mode_factory(binascii.unhexlify(b"0" * 24)), 291 backend 292 ) 293 294 with pytest.raises(ValueError): 295 mode_factory(binascii.unhexlify(b"0" * 24), b"000") 296 297 with pytest.raises(ValueError): 298 mode_factory(binascii.unhexlify(b"0" * 24), b"000000", 2) 299 300 cipher = Cipher( 301 cipher_factory(binascii.unhexlify(b"0" * 32)), 302 mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), 303 backend 304 ) 305 with pytest.raises(ValueError): 306 cipher.encryptor() 307 308 309def hkdf_derive_test(backend, algorithm, params): 310 hkdf = HKDF( 311 algorithm, 312 int(params["l"]), 313 salt=binascii.unhexlify(params["salt"]) or None, 314 info=binascii.unhexlify(params["info"]) or None, 315 backend=backend 316 ) 317 318 okm = hkdf.derive(binascii.unhexlify(params["ikm"])) 319 320 assert okm == binascii.unhexlify(params["okm"]) 321 322 323def hkdf_extract_test(backend, algorithm, params): 324 hkdf = HKDF( 325 algorithm, 326 int(params["l"]), 327 salt=binascii.unhexlify(params["salt"]) or None, 328 info=binascii.unhexlify(params["info"]) or None, 329 backend=backend 330 ) 331 332 prk = hkdf._extract(binascii.unhexlify(params["ikm"])) 333 334 assert prk == binascii.unhexlify(params["prk"]) 335 336 337def hkdf_expand_test(backend, algorithm, params): 338 hkdf = HKDFExpand( 339 algorithm, 340 int(params["l"]), 341 info=binascii.unhexlify(params["info"]) or None, 342 backend=backend 343 ) 344 345 okm = hkdf.derive(binascii.unhexlify(params["prk"])) 346 347 assert okm == binascii.unhexlify(params["okm"]) 348 349 350def generate_hkdf_test(param_loader, path, file_names, algorithm): 351 all_params = _load_all_params(path, file_names, param_loader) 352 353 all_tests = [hkdf_extract_test, hkdf_expand_test, hkdf_derive_test] 354 355 @pytest.mark.parametrize( 356 ("params", "hkdf_test"), 357 itertools.product(all_params, all_tests) 358 ) 359 def test_hkdf(self, backend, params, hkdf_test): 360 hkdf_test(backend, algorithm, params) 361 362 return test_hkdf 363 364 365def generate_kbkdf_counter_mode_test(param_loader, path, file_names): 366 all_params = _load_all_params(path, file_names, param_loader) 367 368 @pytest.mark.parametrize("params", all_params) 369 def test_kbkdf(self, backend, params): 370 kbkdf_counter_mode_test(backend, params) 371 return test_kbkdf 372 373 374def kbkdf_counter_mode_test(backend, params): 375 supported_algorithms = { 376 'hmac_sha1': hashes.SHA1, 377 'hmac_sha224': hashes.SHA224, 378 'hmac_sha256': hashes.SHA256, 379 'hmac_sha384': hashes.SHA384, 380 'hmac_sha512': hashes.SHA512, 381 } 382 383 supported_counter_locations = { 384 "before_fixed": CounterLocation.BeforeFixed, 385 "after_fixed": CounterLocation.AfterFixed, 386 } 387 388 algorithm = supported_algorithms.get(params.get('prf')) 389 if algorithm is None or not backend.hmac_supported(algorithm()): 390 pytest.skip("KBKDF does not support algorithm: {0}".format( 391 params.get('prf') 392 )) 393 394 ctr_loc = supported_counter_locations.get(params.get("ctrlocation")) 395 if ctr_loc is None or not isinstance(ctr_loc, CounterLocation): 396 pytest.skip("Does not support counter location: {0}".format( 397 params.get('ctrlocation') 398 )) 399 400 ctrkdf = KBKDFHMAC( 401 algorithm(), 402 Mode.CounterMode, 403 params['l'] // 8, 404 params['rlen'] // 8, 405 None, 406 ctr_loc, 407 None, 408 None, 409 binascii.unhexlify(params['fixedinputdata']), 410 backend=backend) 411 412 ko = ctrkdf.derive(binascii.unhexlify(params['ki'])) 413 assert binascii.hexlify(ko) == params["ko"] 414 415 416def generate_rsa_verification_test(param_loader, path, file_names, hash_alg, 417 pad_factory): 418 all_params = _load_all_params(path, file_names, param_loader) 419 all_params = [i for i in all_params 420 if i["algorithm"] == hash_alg.name.upper()] 421 422 @pytest.mark.parametrize("params", all_params) 423 def test_rsa_verification(self, backend, params): 424 rsa_verification_test(backend, params, hash_alg, pad_factory) 425 426 return test_rsa_verification 427 428 429def rsa_verification_test(backend, params, hash_alg, pad_factory): 430 public_numbers = rsa.RSAPublicNumbers( 431 e=params["public_exponent"], 432 n=params["modulus"] 433 ) 434 public_key = public_numbers.public_key(backend) 435 pad = pad_factory(params, hash_alg) 436 signature = binascii.unhexlify(params["s"]) 437 msg = binascii.unhexlify(params["msg"]) 438 if params["fail"]: 439 with pytest.raises(InvalidSignature): 440 public_key.verify( 441 signature, 442 msg, 443 pad, 444 hash_alg 445 ) 446 else: 447 public_key.verify( 448 signature, 449 msg, 450 pad, 451 hash_alg 452 ) 453 454 455def _check_rsa_private_numbers(skey): 456 assert skey 457 pkey = skey.public_numbers 458 assert pkey 459 assert pkey.e 460 assert pkey.n 461 assert skey.d 462 assert skey.p * skey.q == pkey.n 463 assert skey.dmp1 == rsa.rsa_crt_dmp1(skey.d, skey.p) 464 assert skey.dmq1 == rsa.rsa_crt_dmq1(skey.d, skey.q) 465 assert skey.iqmp == rsa.rsa_crt_iqmp(skey.p, skey.q) 466 467 468def _check_dsa_private_numbers(skey): 469 assert skey 470 pkey = skey.public_numbers 471 params = pkey.parameter_numbers 472 assert pow(params.g, skey.x, params.p) == pkey.y 473