// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use crate::client::MlsError; use crate::crypto::{CipherSuiteProvider, HpkePublicKey, HpkeSecretKey}; use crate::group::key_schedule::kdf_derive_secret; use alloc::vec; use alloc::vec::Vec; use core::{ fmt::{self, Debug}, ops::Deref, }; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; use zeroize::Zeroizing; use super::hpke_encryption::HpkeEncryptable; #[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct PathSecret( #[mls_codec(with = "mls_rs_codec::byte_vec")] #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] Zeroizing>, ); impl Debug for PathSecret { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { mls_rs_core::debug::pretty_bytes(&self.0) .named("PathSecret") .fmt(f) } } impl Deref for PathSecret { type Target = Vec; fn deref(&self) -> &Self::Target { &self.0 } } impl From> for PathSecret { fn from(data: Vec) -> Self { PathSecret(Zeroizing::new(data)) } } impl From>> for PathSecret { fn from(data: Zeroizing>) -> Self { PathSecret(data) } } impl PathSecret { pub fn random( cipher_suite_provider: &P, ) -> Result { cipher_suite_provider .random_bytes_vec(cipher_suite_provider.kdf_extract_size()) .map(Into::into) .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) } pub fn empty(cipher_suite_provider: &P) -> Self { // Define commit_secret as the all-zero vector of the same length as a path_secret PathSecret::from(vec![0u8; cipher_suite_provider.kdf_extract_size()]) } } impl HpkeEncryptable for PathSecret { const ENCRYPT_LABEL: &'static str = "UpdatePathNode"; fn from_bytes(bytes: Vec) -> Result { Ok(Self(Zeroizing::new(bytes))) } fn get_bytes(&self) -> Result, MlsError> { Ok(self.to_vec()) } } impl PathSecret { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn to_hpke_key_pair( &self, cs: &P, ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> { let node_secret = Zeroizing::new(kdf_derive_secret(cs, self, b"node").await?); cs.kem_derive(&node_secret) .await .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) } } #[derive(Clone, Debug)] pub struct PathSecretGenerator<'a, P> { cipher_suite_provider: &'a P, last: Option, starting_with: Option, } impl<'a, P: CipherSuiteProvider> PathSecretGenerator<'a, P> { pub fn new(cipher_suite_provider: &'a P) -> Self { Self { cipher_suite_provider, last: None, starting_with: None, } } pub fn starting_with(cipher_suite_provider: &'a P, secret: PathSecret) -> Self { Self { starting_with: Some(secret), ..Self::new(cipher_suite_provider) } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn next_secret(&mut self) -> Result { let secret = if let Some(starting_with) = self.starting_with.take() { Ok(starting_with) } else if let Some(last) = self.last.take() { kdf_derive_secret(self.cipher_suite_provider, &last, b"path") .await .map(PathSecret::from) } else { PathSecret::random(self.cipher_suite_provider) }?; self.last = Some(secret.clone()); Ok(secret) } } #[cfg(test)] mod tests { use crate::{ cipher_suite::CipherSuite, client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::{ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider, }, }; use super::*; use alloc::string::String; #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; #[derive(serde::Deserialize, serde::Serialize)] struct TestCase { cipher_suite: u16, generations: Vec, } impl TestCase { #[cfg(not(mls_build_async))] #[cfg_attr(coverage_nightly, coverage(off))] fn generate() -> Vec { CipherSuite::all() .map( #[cfg_attr(coverage_nightly, coverage(off))] |cipher_suite| { let cs_provider = test_cipher_suite_provider(cipher_suite); let mut generator = PathSecretGenerator::new(&cs_provider); let generations = (0..10) .map(|_| hex::encode(&*generator.next_secret().unwrap())) .collect(); TestCase { cipher_suite: cipher_suite.into(), generations, } }, ) .collect() } #[cfg(mls_build_async)] fn generate() -> Vec { panic!("Tests cannot be generated in async mode"); } } fn load_test_cases() -> Vec { load_test_case_json!(path_secret, TestCase::generate()) } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_path_secret_generation() { let cases = load_test_cases(); for test_case in cases { let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else { continue; }; let first_secret = PathSecret::from(hex::decode(&test_case.generations[0]).unwrap()); let mut generator = PathSecretGenerator::starting_with(&cs_provider, first_secret); for expected in &test_case.generations { let generated = hex::encode(&*generator.next_secret().await.unwrap()); assert_eq!(expected, &generated); } } } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_first_path_is_random() { let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let mut generator = PathSecretGenerator::new(&cs_provider); let first_secret = generator.next_secret().await.unwrap(); for _ in 0..100 { let mut next_generator = PathSecretGenerator::new(&cs_provider); let next_secret = next_generator.next_secret().await.unwrap(); assert_ne!(first_secret, next_secret); } } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_starting_with() { let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let secret = PathSecret::random(&cs_provider).unwrap(); let mut generator = PathSecretGenerator::starting_with(&cs_provider, secret.clone()); let first_secret = generator.next_secret().await.unwrap(); let second_secret = generator.next_secret().await.unwrap(); assert_eq!(secret, first_secret); assert_ne!(first_secret, second_secret); } #[test] fn test_empty_path_secret() { for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { let cs_provider = test_cipher_suite_provider(cipher_suite); let empty = PathSecret::empty(&cs_provider); assert_eq!( empty, PathSecret::from(vec![0u8; cs_provider.kdf_extract_size()]) ) } } #[test] fn test_random_path_secret() { let cs_provider = test_cipher_suite_provider(CipherSuite::P256_AES128); let initial = PathSecret::random(&cs_provider).unwrap(); for _ in 0..100 { let next = PathSecret::random(&cs_provider).unwrap(); assert_ne!(next, initial); } } }