1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 use core::{ 7 fmt::{self, Debug}, 8 ops::{Deref, DerefMut}, 9 }; 10 11 use zeroize::Zeroizing; 12 13 use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider}; 14 15 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 16 use mls_rs_core::error::IntoAnyError; 17 18 use super::key_schedule::kdf_expand_with_label; 19 20 pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024; 21 22 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 23 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 24 #[repr(u8)] 25 enum SecretTreeNode { 26 Secret(TreeSecret) = 0u8, 27 Ratchet(SecretRatchets) = 1u8, 28 } 29 30 impl SecretTreeNode { into_secret(self) -> Option<TreeSecret>31 fn into_secret(self) -> Option<TreeSecret> { 32 if let SecretTreeNode::Secret(secret) = self { 33 Some(secret) 34 } else { 35 None 36 } 37 } 38 } 39 40 #[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)] 41 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 42 struct TreeSecret( 43 #[mls_codec(with = "mls_rs_codec::byte_vec")] 44 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] 45 Zeroizing<Vec<u8>>, 46 ); 47 48 impl Debug for TreeSecret { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 50 mls_rs_core::debug::pretty_bytes(&self.0) 51 .named("TreeSecret") 52 .fmt(f) 53 } 54 } 55 56 impl Deref for TreeSecret { 57 type Target = Vec<u8>; 58 deref(&self) -> &Self::Target59 fn deref(&self) -> &Self::Target { 60 &self.0 61 } 62 } 63 64 impl DerefMut for TreeSecret { deref_mut(&mut self) -> &mut Self::Target65 fn deref_mut(&mut self) -> &mut Self::Target { 66 &mut self.0 67 } 68 } 69 70 impl AsRef<[u8]> for TreeSecret { as_ref(&self) -> &[u8]71 fn as_ref(&self) -> &[u8] { 72 &self.0 73 } 74 } 75 76 impl From<Vec<u8>> for TreeSecret { from(vec: Vec<u8>) -> Self77 fn from(vec: Vec<u8>) -> Self { 78 TreeSecret(Zeroizing::new(vec)) 79 } 80 } 81 82 impl From<Zeroizing<Vec<u8>>> for TreeSecret { from(vec: Zeroizing<Vec<u8>>) -> Self83 fn from(vec: Zeroizing<Vec<u8>>) -> Self { 84 TreeSecret(vec) 85 } 86 } 87 88 #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)] 89 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 90 struct TreeSecretsVec<T: TreeIndex> { 91 inner: LargeMap<T, SecretTreeNode>, 92 } 93 94 impl<T: TreeIndex> TreeSecretsVec<T> { set_node(&mut self, index: T, value: SecretTreeNode)95 fn set_node(&mut self, index: T, value: SecretTreeNode) { 96 self.inner.insert(index, value); 97 } 98 take_node(&mut self, index: &T) -> Option<SecretTreeNode>99 fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> { 100 self.inner.remove(index) 101 } 102 } 103 104 #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] 105 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 106 pub struct SecretTree<T: TreeIndex> { 107 known_secrets: TreeSecretsVec<T>, 108 leaf_count: T, 109 } 110 111 impl<T: TreeIndex> SecretTree<T> { empty() -> SecretTree<T>112 pub(crate) fn empty() -> SecretTree<T> { 113 SecretTree { 114 known_secrets: Default::default(), 115 leaf_count: T::zero(), 116 } 117 } 118 } 119 120 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 121 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 122 pub struct SecretRatchets { 123 pub application: SecretKeyRatchet, 124 pub handshake: SecretKeyRatchet, 125 } 126 127 impl SecretRatchets { 128 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>129 pub async fn message_key_generation<P: CipherSuiteProvider>( 130 &mut self, 131 cipher_suite_provider: &P, 132 generation: u32, 133 key_type: KeyType, 134 ) -> Result<MessageKeyData, MlsError> { 135 match key_type { 136 KeyType::Handshake => { 137 self.handshake 138 .get_message_key(cipher_suite_provider, generation) 139 .await 140 } 141 KeyType::Application => { 142 self.application 143 .get_message_key(cipher_suite_provider, generation) 144 .await 145 } 146 } 147 } 148 149 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>150 pub async fn next_message_key<P: CipherSuiteProvider>( 151 &mut self, 152 cipher_suite: &P, 153 key_type: KeyType, 154 ) -> Result<MessageKeyData, MlsError> { 155 match key_type { 156 KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await, 157 KeyType::Application => self.application.next_message_key(cipher_suite).await, 158 } 159 } 160 } 161 162 impl<T: TreeIndex> SecretTree<T> { new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T>163 pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> { 164 let mut known_secrets = TreeSecretsVec::default(); 165 166 let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret)); 167 known_secrets.set_node(leaf_count.root(), root_secret); 168 169 Self { 170 known_secrets, 171 leaf_count, 172 } 173 } 174 175 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] consume_node<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, index: &T, ) -> Result<(), MlsError>176 async fn consume_node<P: CipherSuiteProvider>( 177 &mut self, 178 cipher_suite_provider: &P, 179 index: &T, 180 ) -> Result<(), MlsError> { 181 let node = self.known_secrets.take_node(index); 182 183 if let Some(secret) = node.and_then(|n| n.into_secret()) { 184 let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?; 185 let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?; 186 187 let left_secret = 188 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None) 189 .await?; 190 191 let right_secret = 192 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None) 193 .await?; 194 195 self.known_secrets 196 .set_node(left_index, SecretTreeNode::Secret(left_secret.into())); 197 198 self.known_secrets 199 .set_node(right_index, SecretTreeNode::Secret(right_secret.into())); 200 } 201 202 Ok(()) 203 } 204 205 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] take_leaf_ratchet<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: &T, ) -> Result<SecretRatchets, MlsError>206 async fn take_leaf_ratchet<P: CipherSuiteProvider>( 207 &mut self, 208 cipher_suite: &P, 209 leaf_index: &T, 210 ) -> Result<SecretRatchets, MlsError> { 211 let node_index = leaf_index; 212 213 let node = match self.known_secrets.take_node(node_index) { 214 Some(node) => node, 215 None => { 216 // Start at the root node and work your way down consuming any intermediates needed 217 for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() { 218 self.consume_node(cipher_suite, &i.path).await?; 219 } 220 221 self.known_secrets 222 .take_node(node_index) 223 .ok_or(MlsError::InvalidLeafConsumption)? 224 } 225 }; 226 227 Ok(match node { 228 SecretTreeNode::Ratchet(ratchet) => ratchet, 229 SecretTreeNode::Secret(secret) => SecretRatchets { 230 application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application) 231 .await?, 232 handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?, 233 }, 234 }) 235 } 236 237 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>238 pub async fn next_message_key<P: CipherSuiteProvider>( 239 &mut self, 240 cipher_suite: &P, 241 leaf_index: T, 242 key_type: KeyType, 243 ) -> Result<MessageKeyData, MlsError> { 244 let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; 245 let res = ratchet.next_message_key(cipher_suite, key_type).await?; 246 247 self.known_secrets 248 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); 249 250 Ok(res) 251 } 252 253 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, generation: u32, ) -> Result<MessageKeyData, MlsError>254 pub async fn message_key_generation<P: CipherSuiteProvider>( 255 &mut self, 256 cipher_suite: &P, 257 leaf_index: T, 258 key_type: KeyType, 259 generation: u32, 260 ) -> Result<MessageKeyData, MlsError> { 261 let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; 262 263 let res = ratchet 264 .message_key_generation(cipher_suite, generation, key_type) 265 .await?; 266 267 self.known_secrets 268 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); 269 270 Ok(res) 271 } 272 } 273 274 #[derive(Clone, Copy)] 275 pub enum KeyType { 276 Handshake, 277 Application, 278 } 279 280 #[cfg_attr( 281 all(feature = "ffi", not(test)), 282 safer_ffi_gen::ffi_type(clone, opaque) 283 )] 284 #[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)] 285 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 286 /// AEAD key derived by the MLS secret tree. 287 pub struct MessageKeyData { 288 #[mls_codec(with = "mls_rs_codec::byte_vec")] 289 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] 290 pub(crate) nonce: Zeroizing<Vec<u8>>, 291 #[mls_codec(with = "mls_rs_codec::byte_vec")] 292 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] 293 pub(crate) key: Zeroizing<Vec<u8>>, 294 pub(crate) generation: u32, 295 } 296 297 impl Debug for MessageKeyData { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result298 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 299 f.debug_struct("MessageKeyData") 300 .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce)) 301 .field("key", &mls_rs_core::debug::pretty_bytes(&self.key)) 302 .field("generation", &self.generation) 303 .finish() 304 } 305 } 306 307 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 308 impl MessageKeyData { 309 /// AEAD nonce. 310 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] nonce(&self) -> &[u8]311 pub fn nonce(&self) -> &[u8] { 312 &self.nonce 313 } 314 315 /// AEAD key. 316 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] key(&self) -> &[u8]317 pub fn key(&self) -> &[u8] { 318 &self.key 319 } 320 321 /// Generation of this key within the key schedule. 322 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] generation(&self) -> u32323 pub fn generation(&self) -> u32 { 324 self.generation 325 } 326 } 327 328 #[derive(Debug, Clone, PartialEq)] 329 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 330 pub struct SecretKeyRatchet { 331 secret: TreeSecret, 332 generation: u32, 333 #[cfg(feature = "out_of_order")] 334 history: LargeMap<u32, MessageKeyData>, 335 } 336 337 impl MlsSize for SecretKeyRatchet { mls_encoded_len(&self) -> usize338 fn mls_encoded_len(&self) -> usize { 339 let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret) 340 + self.generation.mls_encoded_len(); 341 342 #[cfg(feature = "out_of_order")] 343 return len + mls_rs_codec::iter::mls_encoded_len(self.history.values()); 344 #[cfg(not(feature = "out_of_order"))] 345 return len; 346 } 347 } 348 349 #[cfg(feature = "out_of_order")] 350 impl MlsEncode for SecretKeyRatchet { mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>351 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { 352 mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?; 353 self.generation.mls_encode(writer)?; 354 mls_rs_codec::iter::mls_encode(self.history.values(), writer) 355 } 356 } 357 358 #[cfg(not(feature = "out_of_order"))] 359 impl MlsEncode for SecretKeyRatchet { mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>360 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { 361 mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?; 362 self.generation.mls_encode(writer) 363 } 364 } 365 366 impl MlsDecode for SecretKeyRatchet { mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>367 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> { 368 Ok(Self { 369 secret: mls_rs_codec::byte_vec::mls_decode(reader)?, 370 generation: u32::mls_decode(reader)?, 371 #[cfg(feature = "out_of_order")] 372 history: mls_rs_codec::iter::mls_decode_collection(reader, |data| { 373 let mut items = LargeMap::default(); 374 375 while !data.is_empty() { 376 let item = MessageKeyData::mls_decode(data)?; 377 items.insert(item.generation, item); 378 } 379 380 Ok(items) 381 })?, 382 }) 383 } 384 } 385 386 impl SecretKeyRatchet { 387 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] new<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], key_type: KeyType, ) -> Result<Self, MlsError>388 async fn new<P: CipherSuiteProvider>( 389 cipher_suite_provider: &P, 390 secret: &[u8], 391 key_type: KeyType, 392 ) -> Result<Self, MlsError> { 393 let label = match key_type { 394 KeyType::Handshake => b"handshake".as_slice(), 395 KeyType::Application => b"application".as_slice(), 396 }; 397 398 let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None) 399 .await 400 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 401 402 Ok(Self { 403 secret: TreeSecret::from(secret), 404 generation: 0, 405 #[cfg(feature = "out_of_order")] 406 history: Default::default(), 407 }) 408 } 409 410 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, ) -> Result<MessageKeyData, MlsError>411 async fn get_message_key<P: CipherSuiteProvider>( 412 &mut self, 413 cipher_suite_provider: &P, 414 generation: u32, 415 ) -> Result<MessageKeyData, MlsError> { 416 #[cfg(feature = "out_of_order")] 417 if generation < self.generation { 418 return self 419 .history 420 .remove_entry(&generation) 421 .map(|(_, mk)| mk) 422 .ok_or(MlsError::KeyMissing(generation)); 423 } 424 425 #[cfg(not(feature = "out_of_order"))] 426 if generation < self.generation { 427 return Err(MlsError::KeyMissing(generation)); 428 } 429 430 let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY; 431 432 if generation > max_generation_allowed { 433 return Err(MlsError::InvalidFutureGeneration(generation)); 434 } 435 436 #[cfg(not(feature = "out_of_order"))] 437 while self.generation < generation { 438 self.next_message_key(cipher_suite_provider)?; 439 } 440 441 #[cfg(feature = "out_of_order")] 442 while self.generation < generation { 443 let key_data = self.next_message_key(cipher_suite_provider).await?; 444 self.history.insert(key_data.generation, key_data); 445 } 446 447 self.next_message_key(cipher_suite_provider).await 448 } 449 450 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, ) -> Result<MessageKeyData, MlsError>451 async fn next_message_key<P: CipherSuiteProvider>( 452 &mut self, 453 cipher_suite_provider: &P, 454 ) -> Result<MessageKeyData, MlsError> { 455 let generation = self.generation; 456 457 let key = MessageKeyData { 458 nonce: self 459 .derive_secret( 460 cipher_suite_provider, 461 b"nonce", 462 cipher_suite_provider.aead_nonce_size(), 463 ) 464 .await?, 465 key: self 466 .derive_secret( 467 cipher_suite_provider, 468 b"key", 469 cipher_suite_provider.aead_key_size(), 470 ) 471 .await?, 472 generation, 473 }; 474 475 self.secret = self 476 .derive_secret( 477 cipher_suite_provider, 478 b"secret", 479 cipher_suite_provider.kdf_extract_size(), 480 ) 481 .await? 482 .into(); 483 484 self.generation = generation + 1; 485 486 Ok(key) 487 } 488 489 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] derive_secret<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, label: &[u8], len: usize, ) -> Result<Zeroizing<Vec<u8>>, MlsError>490 async fn derive_secret<P: CipherSuiteProvider>( 491 &self, 492 cipher_suite_provider: &P, 493 label: &[u8], 494 len: usize, 495 ) -> Result<Zeroizing<Vec<u8>>, MlsError> { 496 kdf_expand_with_label( 497 cipher_suite_provider, 498 self.secret.as_ref(), 499 label, 500 &self.generation.to_be_bytes(), 501 Some(len), 502 ) 503 .await 504 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) 505 } 506 } 507 508 #[cfg(test)] 509 pub(crate) mod test_utils { 510 use alloc::{string::String, vec::Vec}; 511 use mls_rs_core::crypto::CipherSuiteProvider; 512 use zeroize::Zeroizing; 513 514 use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex}; 515 516 use super::{KeyType, SecretKeyRatchet, SecretTree}; 517 get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T>518 pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> { 519 SecretTree::new(leaf_count, Zeroizing::new(secret)) 520 } 521 522 impl SecretTree<u32> { get_root_secret(&self) -> Vec<u8>523 pub(crate) fn get_root_secret(&self) -> Vec<u8> { 524 self.known_secrets 525 .clone() 526 .take_node(&self.leaf_count.root()) 527 .unwrap() 528 .into_secret() 529 .unwrap() 530 .to_vec() 531 } 532 } 533 534 #[derive(Debug, serde::Serialize, serde::Deserialize)] 535 pub struct RatchetInteropTestCase { 536 #[serde(with = "hex::serde")] 537 secret: Vec<u8>, 538 label: String, 539 generation: u32, 540 length: usize, 541 #[serde(with = "hex::serde")] 542 out: Vec<u8>, 543 } 544 545 #[derive(Debug, serde::Serialize, serde::Deserialize)] 546 pub struct InteropTestCase { 547 cipher_suite: u16, 548 derive_tree_secret: RatchetInteropTestCase, 549 } 550 551 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_basic_crypto_test_vectors()552 async fn test_basic_crypto_test_vectors() { 553 let test_cases: Vec<InteropTestCase> = 554 load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new()); 555 556 for test_case in test_cases { 557 if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) { 558 test_case.derive_tree_secret.verify(&cs).await 559 } 560 } 561 } 562 563 impl RatchetInteropTestCase { 564 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] verify<P: CipherSuiteProvider>(&self, cs: &P)565 pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) { 566 let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application) 567 .await 568 .unwrap(); 569 570 ratchet.secret = self.secret.clone().into(); 571 ratchet.generation = self.generation; 572 573 let computed = ratchet 574 .derive_secret(cs, self.label.as_bytes(), self.length) 575 .await 576 .unwrap(); 577 578 assert_eq!(&computed.to_vec(), &self.out); 579 } 580 } 581 } 582 583 #[cfg(test)] 584 mod tests { 585 use alloc::vec; 586 587 use crate::{ 588 cipher_suite::CipherSuite, 589 client::test_utils::TEST_CIPHER_SUITE, 590 crypto::test_utils::{ 591 test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider, 592 }, 593 tree_kem::node::NodeIndex, 594 }; 595 596 #[cfg(not(mls_build_async))] 597 use crate::group::test_utils::random_bytes; 598 599 use super::{test_utils::get_test_tree, *}; 600 601 use assert_matches::assert_matches; 602 603 #[cfg(target_arch = "wasm32")] 604 use wasm_bindgen_test::wasm_bindgen_test as test; 605 606 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_tree()607 async fn test_secret_tree() { 608 test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await; 609 test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await; 610 } 611 612 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_secret_tree_custom<T: TreeIndex>( leaf_count: T, leaves_to_check: Vec<T>, all_deleted: bool, )613 async fn test_secret_tree_custom<T: TreeIndex>( 614 leaf_count: T, 615 leaves_to_check: Vec<T>, 616 all_deleted: bool, 617 ) { 618 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 619 let cs_provider = test_cipher_suite_provider(cipher_suite); 620 621 let test_secret = vec![0u8; cs_provider.kdf_extract_size()]; 622 let mut test_tree = get_test_tree(test_secret, leaf_count.clone()); 623 624 let mut secrets = Vec::<SecretRatchets>::new(); 625 626 for i in &leaves_to_check { 627 let secret = test_tree 628 .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i) 629 .await 630 .unwrap(); 631 632 secrets.push(secret); 633 } 634 635 // Verify the tree is now completely empty 636 assert!(!all_deleted || test_tree.known_secrets.inner.is_empty()); 637 638 // Verify that all the secrets are unique 639 let count = secrets.len(); 640 secrets.dedup(); 641 assert_eq!(count, secrets.len()); 642 } 643 } 644 645 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_key_ratchet()646 async fn test_secret_key_ratchet() { 647 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 648 let provider = test_cipher_suite_provider(cipher_suite); 649 650 let mut app_ratchet = SecretKeyRatchet::new( 651 &provider, 652 &vec![0u8; provider.kdf_extract_size()], 653 KeyType::Application, 654 ) 655 .await 656 .unwrap(); 657 658 let mut handshake_ratchet = SecretKeyRatchet::new( 659 &provider, 660 &vec![0u8; provider.kdf_extract_size()], 661 KeyType::Handshake, 662 ) 663 .await 664 .unwrap(); 665 666 let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap(); 667 let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap(); 668 let app_keys = vec![app_key_one, app_key_two]; 669 670 let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap(); 671 let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap(); 672 let handshake_keys = vec![handshake_key_one, handshake_key_two]; 673 674 // Verify that the keys have different outcomes due to their different labels 675 assert_ne!(app_keys, handshake_keys); 676 677 // Verify that the keys at each generation are different 678 assert_ne!(handshake_keys[0], handshake_keys[1]); 679 } 680 } 681 682 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_get_key()683 async fn test_get_key() { 684 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 685 let provider = test_cipher_suite_provider(cipher_suite); 686 687 let mut ratchet = SecretKeyRatchet::new( 688 &test_cipher_suite_provider(cipher_suite), 689 &vec![0u8; provider.kdf_extract_size()], 690 KeyType::Application, 691 ) 692 .await 693 .unwrap(); 694 695 let mut ratchet_clone = ratchet.clone(); 696 697 // This will generate keys 0 and 1 in ratchet_clone 698 let _ = ratchet_clone.next_message_key(&provider).await.unwrap(); 699 let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap(); 700 701 // Going back in time should result in an error 702 let res = ratchet_clone.get_message_key(&provider, 0).await; 703 assert!(res.is_err()); 704 705 // Calling get key should be the same as calling next until hitting the desired generation 706 let second_key = ratchet 707 .get_message_key(&provider, ratchet_clone.generation - 1) 708 .await 709 .unwrap(); 710 711 assert_eq!(clone_2, second_key) 712 } 713 } 714 715 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_ratchet()716 async fn test_secret_ratchet() { 717 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 718 let provider = test_cipher_suite_provider(cipher_suite); 719 720 let mut ratchet = SecretKeyRatchet::new( 721 &provider, 722 &vec![0u8; provider.kdf_extract_size()], 723 KeyType::Application, 724 ) 725 .await 726 .unwrap(); 727 728 let original_secret = ratchet.secret.clone(); 729 let _ = ratchet.next_message_key(&provider).await.unwrap(); 730 let new_secret = ratchet.secret; 731 assert_ne!(original_secret, new_secret) 732 } 733 } 734 735 #[cfg(feature = "out_of_order")] 736 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_out_of_order_keys()737 async fn test_out_of_order_keys() { 738 let cipher_suite = TEST_CIPHER_SUITE; 739 let provider = test_cipher_suite_provider(cipher_suite); 740 741 let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake) 742 .await 743 .unwrap(); 744 let mut ratchet_clone = ratchet.clone(); 745 746 // Ask for all the keys in order from the original ratchet 747 let mut ordered_keys = Vec::<MessageKeyData>::new(); 748 749 for i in 0..=MAX_RATCHET_BACK_HISTORY { 750 ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap()); 751 } 752 753 // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone 754 let last_key = ratchet_clone 755 .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY) 756 .await 757 .unwrap(); 758 759 assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]); 760 761 // Get all the other keys 762 let mut back_history_keys = Vec::<MessageKeyData>::new(); 763 764 for i in 0..MAX_RATCHET_BACK_HISTORY - 1 { 765 back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap()); 766 } 767 768 assert_eq!( 769 back_history_keys, 770 ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1] 771 ); 772 } 773 774 #[cfg(not(feature = "out_of_order"))] 775 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] out_of_order_keys_should_throw_error()776 async fn out_of_order_keys_should_throw_error() { 777 let cipher_suite = TEST_CIPHER_SUITE; 778 let provider = test_cipher_suite_provider(cipher_suite); 779 780 let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake) 781 .await 782 .unwrap(); 783 784 ratchet.get_message_key(&provider, 10).await.unwrap(); 785 let res = ratchet.get_message_key(&provider, 9).await; 786 assert_matches!(res, Err(MlsError::KeyMissing(9))) 787 } 788 789 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_too_out_of_order()790 async fn test_too_out_of_order() { 791 let cipher_suite = TEST_CIPHER_SUITE; 792 let provider = test_cipher_suite_provider(cipher_suite); 793 794 let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake) 795 .await 796 .unwrap(); 797 798 let res = ratchet 799 .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1) 800 .await; 801 802 let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1; 803 804 assert_matches!( 805 res, 806 Err(MlsError::InvalidFutureGeneration(invalid)) 807 if invalid == invalid_generation 808 ) 809 } 810 811 #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] 812 struct Ratchet { 813 application_keys: Vec<Vec<u8>>, 814 handshake_keys: Vec<Vec<u8>>, 815 } 816 817 #[derive(Debug, serde::Serialize, serde::Deserialize)] 818 struct TestCase { 819 cipher_suite: u16, 820 #[serde(with = "hex::serde")] 821 encryption_secret: Vec<u8>, 822 ratchets: Vec<Ratchet>, 823 } 824 825 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_ratchet_data( secret_tree: &mut SecretTree<NodeIndex>, cipher_suite: CipherSuite, ) -> Vec<Ratchet>826 async fn get_ratchet_data( 827 secret_tree: &mut SecretTree<NodeIndex>, 828 cipher_suite: CipherSuite, 829 ) -> Vec<Ratchet> { 830 let provider = test_cipher_suite_provider(cipher_suite); 831 let mut ratchet_data = Vec::new(); 832 833 for index in 0..16 { 834 let mut ratchets = secret_tree 835 .take_leaf_ratchet(&provider, &(index * 2)) 836 .await 837 .unwrap(); 838 839 let mut application_keys = Vec::new(); 840 841 for _ in 0..20 { 842 let key = ratchets 843 .handshake 844 .next_message_key(&provider) 845 .await 846 .unwrap() 847 .mls_encode_to_vec() 848 .unwrap(); 849 850 application_keys.push(key); 851 } 852 853 let mut handshake_keys = Vec::new(); 854 855 for _ in 0..20 { 856 let key = ratchets 857 .handshake 858 .next_message_key(&provider) 859 .await 860 .unwrap() 861 .mls_encode_to_vec() 862 .unwrap(); 863 864 handshake_keys.push(key); 865 } 866 867 ratchet_data.push(Ratchet { 868 application_keys, 869 handshake_keys, 870 }); 871 } 872 873 ratchet_data 874 } 875 876 #[cfg(not(mls_build_async))] 877 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_vector() -> Vec<TestCase>878 fn generate_test_vector() -> Vec<TestCase> { 879 CipherSuite::all() 880 .map(|cipher_suite| { 881 let provider = test_cipher_suite_provider(cipher_suite); 882 let encryption_secret = random_bytes(provider.kdf_extract_size()); 883 884 let mut secret_tree = 885 SecretTree::new(16, Zeroizing::new(encryption_secret.clone())); 886 887 TestCase { 888 cipher_suite: cipher_suite.into(), 889 encryption_secret, 890 ratchets: get_ratchet_data(&mut secret_tree, cipher_suite), 891 } 892 }) 893 .collect() 894 } 895 896 #[cfg(mls_build_async)] generate_test_vector() -> Vec<TestCase>897 fn generate_test_vector() -> Vec<TestCase> { 898 panic!("Tests cannot be generated in async mode"); 899 } 900 901 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_tree_test_vectors()902 async fn test_secret_tree_test_vectors() { 903 let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector()); 904 905 for case in test_cases { 906 let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else { 907 continue; 908 }; 909 910 let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret)); 911 let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await; 912 913 assert_eq!(ratchet_data, case.ratchets); 914 } 915 } 916 } 917 918 #[cfg(all(test, feature = "rfc_compliant", feature = "std"))] 919 mod interop_tests { 920 #[cfg(not(mls_build_async))] 921 use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider}; 922 use zeroize::Zeroizing; 923 924 use crate::{ 925 crypto::test_utils::try_test_cipher_suite_provider, 926 group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType}, 927 }; 928 929 use super::SecretTree; 930 931 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] interop_test_vector()932 async fn interop_test_vector() { 933 // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json 934 let test_cases = load_interop_test_cases(); 935 936 for case in test_cases { 937 let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else { 938 continue; 939 }; 940 941 case.sender_data.verify(&cs).await; 942 943 let mut tree = SecretTree::new( 944 case.leaves.len() as u32, 945 Zeroizing::new(case.encryption_secret), 946 ); 947 948 for (index, leaves) in case.leaves.iter().enumerate() { 949 for leaf in leaves.iter() { 950 let key = tree 951 .message_key_generation( 952 &cs, 953 (index as u32) * 2, 954 KeyType::Application, 955 leaf.generation, 956 ) 957 .await 958 .unwrap(); 959 960 assert_eq!(key.key.to_vec(), leaf.application_key); 961 assert_eq!(key.nonce.to_vec(), leaf.application_nonce); 962 963 let key = tree 964 .message_key_generation( 965 &cs, 966 (index as u32) * 2, 967 KeyType::Handshake, 968 leaf.generation, 969 ) 970 .await 971 .unwrap(); 972 973 assert_eq!(key.key.to_vec(), leaf.handshake_key); 974 assert_eq!(key.nonce.to_vec(), leaf.handshake_nonce); 975 } 976 } 977 } 978 } 979 980 #[derive(Debug, serde::Serialize, serde::Deserialize)] 981 struct InteropTestCase { 982 cipher_suite: u16, 983 #[serde(with = "hex::serde")] 984 encryption_secret: Vec<u8>, 985 sender_data: InteropSenderData, 986 leaves: Vec<Vec<InteropLeaf>>, 987 } 988 989 #[derive(Debug, serde::Serialize, serde::Deserialize)] 990 struct InteropLeaf { 991 generation: u32, 992 #[serde(with = "hex::serde")] 993 application_key: Vec<u8>, 994 #[serde(with = "hex::serde")] 995 application_nonce: Vec<u8>, 996 #[serde(with = "hex::serde")] 997 handshake_key: Vec<u8>, 998 #[serde(with = "hex::serde")] 999 handshake_nonce: Vec<u8>, 1000 } 1001 load_interop_test_cases() -> Vec<InteropTestCase>1002 fn load_interop_test_cases() -> Vec<InteropTestCase> { 1003 load_test_case_json!(secret_tree_interop, generate_test_vector()) 1004 } 1005 1006 #[cfg(not(mls_build_async))] 1007 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_vector() -> Vec<InteropTestCase>1008 fn generate_test_vector() -> Vec<InteropTestCase> { 1009 let mut test_cases = vec![]; 1010 1011 for cs in CipherSuite::all() { 1012 let Some(cs) = try_test_cipher_suite_provider(*cs) else { 1013 continue; 1014 }; 1015 1016 let gens = [0, 15]; 1017 let tree_sizes = [1, 8, 32]; 1018 1019 for n_leaves in tree_sizes { 1020 let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(); 1021 1022 let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone())); 1023 1024 let leaves = (0..n_leaves) 1025 .map(|leaf| { 1026 gens.into_iter() 1027 .map(|gen| { 1028 let index = leaf * 2u32; 1029 1030 let handshake_key = tree 1031 .message_key_generation(&cs, index, KeyType::Handshake, gen) 1032 .unwrap(); 1033 1034 let app_key = tree 1035 .message_key_generation(&cs, index, KeyType::Application, gen) 1036 .unwrap(); 1037 1038 InteropLeaf { 1039 generation: gen, 1040 application_key: app_key.key.to_vec(), 1041 application_nonce: app_key.nonce.to_vec(), 1042 handshake_key: handshake_key.key.to_vec(), 1043 handshake_nonce: handshake_key.nonce.to_vec(), 1044 } 1045 }) 1046 .collect() 1047 }) 1048 .collect(); 1049 1050 let case = InteropTestCase { 1051 cipher_suite: *cs.cipher_suite(), 1052 encryption_secret, 1053 sender_data: InteropSenderData::new(&cs), 1054 leaves, 1055 }; 1056 1057 test_cases.push(case); 1058 } 1059 } 1060 1061 test_cases 1062 } 1063 1064 #[cfg(mls_build_async)] generate_test_vector() -> Vec<InteropTestCase>1065 fn generate_test_vector() -> Vec<InteropTestCase> { 1066 panic!("Tests cannot be generated in async mode"); 1067 } 1068 } 1069