// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use alloc::vec::Vec; use itertools::Itertools; use crate::crypto::HpkeContextR; use super::{ CipherSuiteProvider, CryptoProvider, HpkeCiphertext, HpkeContextS, HpkePublicKey, HpkeSecretKey, }; const PATH: &str = concat!( env!("CARGO_MANIFEST_DIR"), "/test_data/crypto_provider.json" ); #[cfg(any(target_arch = "wasm32", not(feature = "std")))] const SERIALIZED_TEST_SUITES: &[u8] = include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/test_data/crypto_provider.json" )); pub use hpke_rfc_conformance::{ verify_hpke_context_tests, verify_hpke_encap_tests, EncapOutput, TestHpke, }; pub const DATA_SIZES: [usize; 5] = [0, 1, 16, 123, 2000]; #[derive(serde::Serialize, serde::Deserialize, Default)] struct TestSuite { cipher_suite: u16, #[serde(default)] signature_tests: Vec, #[serde(default)] aead_tests: Vec, #[serde(default)] hpke_tests: HpkeTestCases, #[serde(default)] hkdf_tests: Vec, #[serde(default)] mac_tests: Vec, #[serde(default)] hash_tests: Vec, } #[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))] #[cfg_attr(coverage_nightly, coverage(off))] pub fn generate_tests(crypto: &C) { for cs in crypto.supported_cipher_suites() { crypto.cipher_suite_provider(cs).unwrap(); } let mut test_suites = create_or_load_tests(crypto); for test_suite in test_suites.iter_mut() { let cs = test_suite.cipher_suite.into(); let cs = crypto.cipher_suite_provider(cs).unwrap(); test_suite.signature_tests = generate_signature_tests(&cs); test_suite.hpke_tests = generate_hpke_tests(&cs); test_suite.hkdf_tests = generate_hkdf_tests(&cs); } std::fs::write(PATH, serde_json::to_string_pretty(&test_suites).unwrap()).unwrap(); } #[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))] #[cfg_attr(coverage_nightly, coverage(off))] fn create_or_load_tests(crypto: &C) -> Vec { if std::path::Path::new(PATH).exists() { serde_json::from_slice(&std::fs::read(PATH).unwrap()).unwrap() } else { crypto .supported_cipher_suites() .into_iter() .map(|cipher_suite| TestSuite { cipher_suite: cipher_suite.into(), ..Default::default() }) .collect() } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn verify_tests(crypto: &C, signature_secret_key_compatible: bool) { #[cfg(any(target_arch = "wasm32", not(feature = "std")))] let test_suites: Vec = serde_json::from_slice(SERIALIZED_TEST_SUITES).unwrap(); #[cfg(all(not(target_arch = "wasm32"), feature = "std"))] let test_suites: Vec = serde_json::from_slice(&std::fs::read(PATH).unwrap()).unwrap(); for test_suite in test_suites { let test_cs = test_suite.cipher_suite.into(); let Some(cs) = crypto.cipher_suite_provider(test_cs) else { continue; }; assert_eq!(cs.cipher_suite(), test_cs); verify_hkdf_tests(&cs, test_suite.hkdf_tests).await; verify_aead_tests(&cs, test_suite.aead_tests).await; verify_mac_tests(&cs, test_suite.mac_tests).await; verify_hpke_tests(&cs, test_suite.hpke_tests).await; verify_signature_tests( &cs, test_suite.signature_tests, signature_secret_key_compatible, ) .await; verify_hash_tests(&cs, test_suite.hash_tests).await; } } #[derive(serde::Serialize, serde::Deserialize)] struct SignatureTestCase { #[serde(with = "hex::serde")] secret: Vec, #[serde(with = "hex::serde")] public: Vec, #[serde(with = "hex::serde")] data: Vec, #[serde(with = "hex::serde")] signature: Vec, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_signature_tests( cs: &C, test_cases: Vec, secret_key_compatible: bool, ) { // Checks that `cs` can sign and verify let generated = generate_signature_tests(cs).await; for (test_case, is_generated) in test_cases .into_iter() .map(|tc| (tc, false)) .chain(generated.into_iter().map(|tc| (tc, true))) { let public = test_case.public.into(); // Checks that `cs` can verify signatures generated by itself and another implementation cs.verify(&public, &test_case.signature, &test_case.data) .await .unwrap(); if is_generated || secret_key_compatible { let secret = test_case.secret.into(); let derived = cs.signature_key_derive_public(&secret).await.unwrap(); cs.sign(&secret, b"hello world").await.unwrap(); assert_eq!(derived, public); } } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(coverage_nightly, coverage(off))] async fn generate_signature_tests(cs: &C) -> Vec { let mut tests = Vec::new(); for data_size in DATA_SIZES { let data = cs.random_bytes_vec(data_size).unwrap(); let (secret, public) = cs.signature_key_generate().await.unwrap(); let signature = cs.sign(&secret, &data).await.unwrap(); tests.push(SignatureTestCase { secret: secret.to_vec(), public: public.to_vec(), data, signature, }); } tests } // Test vectors from the RFC #[derive(serde::Deserialize, serde::Serialize)] struct AeadTestCase { #[serde(with = "hex::serde")] pub key: Vec, #[serde(with = "hex::serde")] pub iv: Vec, #[serde(with = "hex::serde")] pub ct: Vec, #[serde(with = "hex::serde")] pub aad: Vec, #[serde(with = "hex::serde")] pub pt: Vec, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_aead_tests(cs: &C, test_cases: Vec) { for case in test_cases { let ciphertext = cs .aead_seal(&case.key, &case.pt, Some(&case.aad), &case.iv) .await .unwrap(); assert_eq!(ciphertext, case.ct); let plaintext = cs .aead_open(&case.key, &ciphertext, Some(&case.aad), &case.iv) .await .unwrap(); assert_eq!(plaintext.to_vec(), case.pt); } } #[derive(serde::Serialize, serde::Deserialize, Default)] struct HpkeTestCases { #[serde(with = "hex::serde")] ikm: Vec, #[serde(with = "hex::serde")] secret: Vec, #[serde(with = "hex::serde")] public: Vec, seal_tests: Vec, export_tests: Vec, } #[derive(serde::Serialize, serde::Deserialize)] struct HpkeSealTestCase { #[serde(with = "hex::serde")] plaintext: Vec, #[serde(with = "hex::serde")] info: Vec, #[serde(with = "hex::serde")] aad: Vec, // Seal and open #[serde(with = "hex::serde")] sealed_kem_output: Vec, #[serde(with = "hex::serde")] sealed_ciphertext: Vec, // Setup s and r #[serde(with = "hex::serde")] setup_s_kem_output: Vec, #[serde(with = "hex::serde")] setup_s_ciphertext: Vec, } #[derive(serde::Serialize, serde::Deserialize)] struct HpkeExportTestCase { #[serde(with = "hex::serde")] info: Vec, #[serde(with = "hex::serde")] kem_output: Vec, #[serde(with = "hex::serde")] exporter_context: Vec, exported_len: usize, #[serde(with = "hex::serde")] exported: Vec, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_hpke_tests(cs: &C, test_cases: HpkeTestCases) { let generated = generate_hpke_tests(cs).await; verify_hpke_test(cs, generated).await; verify_hpke_test(cs, test_cases).await; } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_hpke_test(cs: &C, test_cases: HpkeTestCases) { let (secret, public) = cs.kem_derive(&test_cases.ikm).await.unwrap(); assert_eq!(&secret, &test_cases.secret.into()); assert_eq!(&public, &test_cases.public.into()); for test in test_cases.seal_tests { let ct = HpkeCiphertext { kem_output: test.sealed_kem_output.clone(), ciphertext: test.sealed_ciphertext.clone(), }; test_open_ciphertext(cs, &secret, &public, &ct, &test).await; let ct = HpkeCiphertext { kem_output: test.setup_s_kem_output.clone(), ciphertext: test.setup_s_ciphertext.clone(), }; test_open_ciphertext(cs, &secret, &public, &ct, &test).await; } for test in test_cases.export_tests { let context_r = cs .hpke_setup_r(&test.kem_output, &secret, &public, &test.info) .await .unwrap(); let exported = context_r .export(&test.exporter_context, test.exported_len) .await .unwrap(); assert_eq!(exported, test.exported); } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn test_open_ciphertext( cs: &C, secret: &HpkeSecretKey, public: &HpkePublicKey, ct: &HpkeCiphertext, test: &HpkeSealTestCase, ) { let aad = (!test.aad.is_empty()).then_some(test.aad.as_slice()); let opened = cs .hpke_open(ct, secret, public, &test.info, aad) .await .unwrap(); assert_eq!(&opened, &test.plaintext); let mut context_r = cs .hpke_setup_r(&ct.kem_output, secret, public, &test.info) .await .unwrap(); let opened = context_r.open(aad, &ct.ciphertext).await.unwrap(); assert_eq!(&opened, &test.plaintext); } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(coverage_nightly, coverage(off))] async fn generate_hpke_tests(cs: &C) -> HpkeTestCases { let ikm = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(); let (secret, public) = cs.kem_derive(&ikm).await.unwrap(); let sizes_iter = DATA_SIZES.iter().copied(); let mut seal_tests = Vec::new(); for ((pt_size, info_size), aad_size) in sizes_iter .clone() .skip(1) .cartesian_product(sizes_iter.clone()) .cartesian_product(sizes_iter.clone()) { let plaintext = cs.random_bytes_vec(pt_size).unwrap(); let info = cs.random_bytes_vec(info_size).unwrap(); let aad = cs.random_bytes_vec(aad_size).unwrap(); let sealed = cs .hpke_seal(&public, &info, (aad_size > 0).then_some(&aad), &plaintext) .await .unwrap(); let (setup_s_kem_output, mut context_s) = cs.hpke_setup_s(&public, &info).await.unwrap(); let setup_s_ciphertext = context_s .seal((aad_size > 0).then_some(&aad), &plaintext) .await .unwrap(); seal_tests.push(HpkeSealTestCase { plaintext, info, aad, sealed_kem_output: sealed.kem_output, sealed_ciphertext: sealed.ciphertext, setup_s_kem_output, setup_s_ciphertext, }) } let mut export_tests = Vec::new(); for ((context_len, exported_len), info_size) in sizes_iter .clone() .cartesian_product(sizes_iter.clone().skip(1)) .cartesian_product(sizes_iter) { let exporter_context = cs.random_bytes_vec(context_len).unwrap(); let info = cs.random_bytes_vec(info_size).unwrap(); let (kem_output, context) = cs.hpke_setup_s(&public, &info).await.unwrap(); let exported = context .export(&exporter_context, exported_len) .await .unwrap(); export_tests.push(HpkeExportTestCase { info, kem_output, exporter_context, exported_len, exported, }); } HpkeTestCases { ikm, secret: secret.to_vec(), public: public.to_vec(), seal_tests, export_tests, } } #[derive(serde::Deserialize, serde::Serialize)] struct HkdfTestCase { #[serde(with = "hex::serde")] pub ikm: Vec, #[serde(with = "hex::serde")] pub salt: Vec, #[serde(with = "hex::serde")] pub info: Vec, pub len: usize, #[serde(with = "hex::serde")] pub prk: Vec, #[serde(with = "hex::serde")] pub okm: Vec, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_hkdf_tests(cs: &C, test_cases: Vec) { for case in test_cases { let extracted = cs.kdf_extract(&case.salt, &case.ikm).await.unwrap(); assert_eq!(extracted.to_vec(), case.prk); let expanded = cs .kdf_expand(&case.prk, &case.info, case.len) .await .unwrap(); assert_eq!(expanded.to_vec(), case.okm); } } #[cfg(all(not(mls_build_async), not(target_arch = "wasm32"), feature = "std"))] #[cfg_attr(coverage_nightly, coverage(off))] fn generate_hkdf_tests(cs: &C) -> Vec { let iter = DATA_SIZES.iter().copied(); let iter = iter .clone() .skip(1) .cartesian_product(iter.clone()) .cartesian_product(iter.clone()) .cartesian_product(iter.skip(1)); iter.map(|(((ikm_size, salt_size), info_size), len)| { let ikm = cs.random_bytes_vec(ikm_size).unwrap(); let salt = cs.random_bytes_vec(salt_size).unwrap(); let info = cs.random_bytes_vec(info_size).unwrap(); let prk = cs.kdf_extract(&salt, &ikm).unwrap().to_vec(); let okm = cs.kdf_expand(&prk, &info, len).unwrap().to_vec(); HkdfTestCase { ikm, salt, info, len, prk, okm, } }) .collect() } // Test vectors from RFC 4231 #[derive(serde::Deserialize, serde::Serialize)] struct MacTestCase { #[serde(with = "hex::serde")] key: Vec, #[serde(with = "hex::serde")] data: Vec, #[serde(with = "hex::serde")] tag: Vec, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_mac_tests(cs: &C, test_cases: Vec) { for case in test_cases { let computed = cs.mac(&case.key, &case.data).await.unwrap(); assert_eq!(computed, case.tag); } } #[derive(serde::Deserialize, serde::Serialize)] struct HashTestCase { #[serde(with = "hex::serde")] input: Vec, #[serde(with = "hex::serde")] output: Vec, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_hash_tests(cs: &C, test_cases: Vec) { for case in test_cases { let computed = cs.hash(&case.input).await.unwrap(); assert_eq!(computed, case.output); } } mod hpke_rfc_conformance { use alloc::vec::Vec; use crate::crypto::{CipherSuite, HpkeContextR, HpkeContextS, HpkeModeId}; #[derive(serde::Deserialize, Debug, Clone)] pub struct TestCaseAlgo { pub kem_id: u16, pub kdf_id: u16, pub aead_id: u16, pub mode: u8, } impl TestCaseAlgo { fn cipher_suite(&self) -> Option { if ![HpkeModeId::Base as u8, HpkeModeId::Psk as u8].contains(&self.mode) { return None; } match (self.kem_id, self.kdf_id, self.aead_id) { (0x0010, 0x0001, 0x0001) => Some(CipherSuite::P256_AES128), (0x0011, 0x0002, 0x0002) => Some(CipherSuite::P384_AES256), (0x0012, 0x0003, 0x0002) => Some(CipherSuite::P521_AES256), (0x0020, 0x0001, 0x0001) => Some(CipherSuite::CURVE25519_AES128), (0x0020, 0x0001, 0x0003) => Some(CipherSuite::CURVE25519_CHACHA), (0x0021, 0x0003, 0x0002) => Some(CipherSuite::CURVE448_AES256), (0x0021, 0x0003, 0x0003) => Some(CipherSuite::CURVE448_CHACHA), _ => None, } } } #[derive(serde::Deserialize, Debug)] struct TestCase { #[serde(flatten)] algo: TestCaseAlgo, #[serde(with = "hex::serde", rename(deserialize = "pkRm"))] pk_rm: Vec, #[serde(with = "hex::serde", rename(deserialize = "skRm"))] sk_rm: Vec, #[serde(with = "hex::serde", rename(deserialize = "ikmE"))] ikm_e: Vec, #[serde(with = "hex::serde")] shared_secret: Vec, #[serde(with = "hex::serde")] enc: Vec, #[serde(with = "hex::serde")] exporter_secret: Vec, #[serde(with = "hex::serde")] base_nonce: Vec, #[serde(with = "hex::serde")] key: Vec, encryptions: Vec, exports: Vec, } #[derive(serde::Deserialize, Debug)] struct EncryptionTestCase { #[serde(with = "hex::serde", rename = "pt")] plaintext: Vec, #[serde(with = "hex::serde")] aad: Vec, #[serde(with = "hex::serde", rename = "ct")] ciphertext: Vec, } #[derive(serde::Deserialize, Debug)] struct ExportTestCase { #[serde(with = "hex::serde")] exporter_context: Vec, #[serde(rename = "L")] length: usize, #[serde(with = "hex::serde")] exported_value: Vec, } #[cfg(any(target_arch = "wasm32", not(feature = "std")))] fn get_test_cases() -> Vec { let bytes = include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/test_data/test_hpke.json" )); serde_json::from_slice(bytes).unwrap() } #[cfg(all(not(target_arch = "wasm32"), feature = "std"))] fn get_test_cases() -> Vec { let path = concat!(env!("CARGO_MANIFEST_DIR"), "/test_data/test_hpke.json"); serde_json::from_slice(&std::fs::read(path).unwrap()).unwrap() } pub struct EncapOutput { pub enc: Vec, pub shared_secret: Vec, } impl EncapOutput { pub fn new(enc: Vec, shared_secret: Vec) -> Self { Self { enc, shared_secret } } } pub trait TestHpke { type ContextS: HpkeContextS; type ContextR: HpkeContextR; fn hpke_context( &self, key: Vec, base_nonce: Vec, exporter_secret: Vec, ) -> (Self::ContextS, Self::ContextR); fn encap(&mut self, ikm_e: Vec, pk_rm: Vec) -> EncapOutput; fn decap(&mut self, enc: Vec, sk_rm: Vec, pk_rm: Vec) -> Vec; } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn verify_hpke_context_tests(hpke: &C, cipher_suite: CipherSuite) { for test_case in get_test_cases() .into_iter() .filter(|tc| matches!(tc.algo.cipher_suite(), Some(c) if c == cipher_suite)) { let (mut context_s, mut context_r) = hpke.hpke_context( test_case.key, test_case.base_nonce, test_case.exporter_secret, ); for enc_test_case in test_case.encryptions { // Encrypt let ct = context_s .seal(Some(&enc_test_case.aad), &enc_test_case.plaintext) .await .unwrap(); assert_eq!(ct, enc_test_case.ciphertext); // Decrypt let pt = context_r.open(Some(&enc_test_case.aad), &ct).await.unwrap(); assert_eq!(pt, enc_test_case.plaintext); } for test in test_case.exports { let exported_s = context_s.export(&test.exporter_context, test.length).await; assert_eq!(exported_s.unwrap(), test.exported_value); let exported_r = context_r.export(&test.exporter_context, test.length).await; assert_eq!(exported_r.unwrap(), test.exported_value); } } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn verify_hpke_encap_tests(hpke: &mut C, cipher_suite: CipherSuite) { for test_case in get_test_cases() .into_iter() .filter(|tc| matches!(tc.algo.cipher_suite(), Some(c) if c == cipher_suite)) { let out = hpke.encap(test_case.ikm_e, test_case.pk_rm.clone()); assert_eq!(&out.enc, &test_case.enc); assert_eq!(&out.shared_secret, &test_case.shared_secret); let shared_secret = hpke.decap(test_case.enc, test_case.sk_rm, test_case.pk_rm); assert_eq!(shared_secret, test_case.shared_secret); } } }