• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::{
7     fmt::{self, Debug},
8     ops::{Deref, DerefMut},
9 };
10 
11 use zeroize::Zeroizing;
12 
13 use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider};
14 
15 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
16 use mls_rs_core::error::IntoAnyError;
17 
18 use super::key_schedule::kdf_expand_with_label;
19 
/// Largest distance, in generations, that a [`SecretKeyRatchet`] may be asked to jump
/// ahead of its current generation in a single lookup.
pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
21 
22 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
23 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
24 #[repr(u8)]
25 enum SecretTreeNode {
26     Secret(TreeSecret) = 0u8,
27     Ratchet(SecretRatchets) = 1u8,
28 }
29 
30 impl SecretTreeNode {
into_secret(self) -> Option<TreeSecret>31     fn into_secret(self) -> Option<TreeSecret> {
32         if let SecretTreeNode::Secret(secret) = self {
33             Some(secret)
34         } else {
35             None
36         }
37     }
38 }
39 
40 #[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
41 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
42 struct TreeSecret(
43     #[mls_codec(with = "mls_rs_codec::byte_vec")]
44     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
45     Zeroizing<Vec<u8>>,
46 );
47 
48 impl Debug for TreeSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result49     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50         mls_rs_core::debug::pretty_bytes(&self.0)
51             .named("TreeSecret")
52             .fmt(f)
53     }
54 }
55 
56 impl Deref for TreeSecret {
57     type Target = Vec<u8>;
58 
deref(&self) -> &Self::Target59     fn deref(&self) -> &Self::Target {
60         &self.0
61     }
62 }
63 
64 impl DerefMut for TreeSecret {
deref_mut(&mut self) -> &mut Self::Target65     fn deref_mut(&mut self) -> &mut Self::Target {
66         &mut self.0
67     }
68 }
69 
70 impl AsRef<[u8]> for TreeSecret {
as_ref(&self) -> &[u8]71     fn as_ref(&self) -> &[u8] {
72         &self.0
73     }
74 }
75 
76 impl From<Vec<u8>> for TreeSecret {
from(vec: Vec<u8>) -> Self77     fn from(vec: Vec<u8>) -> Self {
78         TreeSecret(Zeroizing::new(vec))
79     }
80 }
81 
82 impl From<Zeroizing<Vec<u8>>> for TreeSecret {
from(vec: Zeroizing<Vec<u8>>) -> Self83     fn from(vec: Zeroizing<Vec<u8>>) -> Self {
84         TreeSecret(vec)
85     }
86 }
87 
88 #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
89 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
90 struct TreeSecretsVec<T: TreeIndex> {
91     inner: LargeMap<T, SecretTreeNode>,
92 }
93 
94 impl<T: TreeIndex> TreeSecretsVec<T> {
set_node(&mut self, index: T, value: SecretTreeNode)95     fn set_node(&mut self, index: T, value: SecretTreeNode) {
96         self.inner.insert(index, value);
97     }
98 
take_node(&mut self, index: &T) -> Option<SecretTreeNode>99     fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
100         self.inner.remove(index)
101     }
102 }
103 
104 #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
105 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
106 pub struct SecretTree<T: TreeIndex> {
107     known_secrets: TreeSecretsVec<T>,
108     leaf_count: T,
109 }
110 
111 impl<T: TreeIndex> SecretTree<T> {
empty() -> SecretTree<T>112     pub(crate) fn empty() -> SecretTree<T> {
113         SecretTree {
114             known_secrets: Default::default(),
115             leaf_count: T::zero(),
116         }
117     }
118 }
119 
120 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
121 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
122 pub struct SecretRatchets {
123     pub application: SecretKeyRatchet,
124     pub handshake: SecretKeyRatchet,
125 }
126 
127 impl SecretRatchets {
128     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>129     pub async fn message_key_generation<P: CipherSuiteProvider>(
130         &mut self,
131         cipher_suite_provider: &P,
132         generation: u32,
133         key_type: KeyType,
134     ) -> Result<MessageKeyData, MlsError> {
135         match key_type {
136             KeyType::Handshake => {
137                 self.handshake
138                     .get_message_key(cipher_suite_provider, generation)
139                     .await
140             }
141             KeyType::Application => {
142                 self.application
143                     .get_message_key(cipher_suite_provider, generation)
144                     .await
145             }
146         }
147     }
148 
149     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>150     pub async fn next_message_key<P: CipherSuiteProvider>(
151         &mut self,
152         cipher_suite: &P,
153         key_type: KeyType,
154     ) -> Result<MessageKeyData, MlsError> {
155         match key_type {
156             KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await,
157             KeyType::Application => self.application.next_message_key(cipher_suite).await,
158         }
159     }
160 }
161 
162 impl<T: TreeIndex> SecretTree<T> {
new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T>163     pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
164         let mut known_secrets = TreeSecretsVec::default();
165 
166         let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
167         known_secrets.set_node(leaf_count.root(), root_secret);
168 
169         Self {
170             known_secrets,
171             leaf_count,
172         }
173     }
174 
175     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
consume_node<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, index: &T, ) -> Result<(), MlsError>176     async fn consume_node<P: CipherSuiteProvider>(
177         &mut self,
178         cipher_suite_provider: &P,
179         index: &T,
180     ) -> Result<(), MlsError> {
181         let node = self.known_secrets.take_node(index);
182 
183         if let Some(secret) = node.and_then(|n| n.into_secret()) {
184             let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
185             let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
186 
187             let left_secret =
188                 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
189                     .await?;
190 
191             let right_secret =
192                 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
193                     .await?;
194 
195             self.known_secrets
196                 .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
197 
198             self.known_secrets
199                 .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
200         }
201 
202         Ok(())
203     }
204 
205     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
take_leaf_ratchet<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: &T, ) -> Result<SecretRatchets, MlsError>206     async fn take_leaf_ratchet<P: CipherSuiteProvider>(
207         &mut self,
208         cipher_suite: &P,
209         leaf_index: &T,
210     ) -> Result<SecretRatchets, MlsError> {
211         let node_index = leaf_index;
212 
213         let node = match self.known_secrets.take_node(node_index) {
214             Some(node) => node,
215             None => {
216                 // Start at the root node and work your way down consuming any intermediates needed
217                 for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
218                     self.consume_node(cipher_suite, &i.path).await?;
219                 }
220 
221                 self.known_secrets
222                     .take_node(node_index)
223                     .ok_or(MlsError::InvalidLeafConsumption)?
224             }
225         };
226 
227         Ok(match node {
228             SecretTreeNode::Ratchet(ratchet) => ratchet,
229             SecretTreeNode::Secret(secret) => SecretRatchets {
230                 application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
231                     .await?,
232                 handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
233             },
234         })
235     }
236 
237     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>238     pub async fn next_message_key<P: CipherSuiteProvider>(
239         &mut self,
240         cipher_suite: &P,
241         leaf_index: T,
242         key_type: KeyType,
243     ) -> Result<MessageKeyData, MlsError> {
244         let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
245         let res = ratchet.next_message_key(cipher_suite, key_type).await?;
246 
247         self.known_secrets
248             .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
249 
250         Ok(res)
251     }
252 
253     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, generation: u32, ) -> Result<MessageKeyData, MlsError>254     pub async fn message_key_generation<P: CipherSuiteProvider>(
255         &mut self,
256         cipher_suite: &P,
257         leaf_index: T,
258         key_type: KeyType,
259         generation: u32,
260     ) -> Result<MessageKeyData, MlsError> {
261         let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
262 
263         let res = ratchet
264             .message_key_generation(cipher_suite, generation, key_type)
265             .await?;
266 
267         self.known_secrets
268             .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
269 
270         Ok(res)
271     }
272 }
273 
/// Selects which of a leaf's two ratchets a message key is taken from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyType {
    /// Keys protecting handshake messages.
    Handshake,
    /// Keys protecting application messages.
    Application,
}
279 
280 #[cfg_attr(
281     all(feature = "ffi", not(test)),
282     safer_ffi_gen::ffi_type(clone, opaque)
283 )]
284 #[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
285 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
286 /// AEAD key derived by the MLS secret tree.
287 pub struct MessageKeyData {
288     #[mls_codec(with = "mls_rs_codec::byte_vec")]
289     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
290     pub(crate) nonce: Zeroizing<Vec<u8>>,
291     #[mls_codec(with = "mls_rs_codec::byte_vec")]
292     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
293     pub(crate) key: Zeroizing<Vec<u8>>,
294     pub(crate) generation: u32,
295 }
296 
297 impl Debug for MessageKeyData {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result298     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
299         f.debug_struct("MessageKeyData")
300             .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
301             .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
302             .field("generation", &self.generation)
303             .finish()
304     }
305 }
306 
307 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
308 impl MessageKeyData {
309     /// AEAD nonce.
310     #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
nonce(&self) -> &[u8]311     pub fn nonce(&self) -> &[u8] {
312         &self.nonce
313     }
314 
315     /// AEAD key.
316     #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
key(&self) -> &[u8]317     pub fn key(&self) -> &[u8] {
318         &self.key
319     }
320 
321     /// Generation of this key within the key schedule.
322     #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
generation(&self) -> u32323     pub fn generation(&self) -> u32 {
324         self.generation
325     }
326 }
327 
328 #[derive(Debug, Clone, PartialEq)]
329 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
330 pub struct SecretKeyRatchet {
331     secret: TreeSecret,
332     generation: u32,
333     #[cfg(feature = "out_of_order")]
334     history: LargeMap<u32, MessageKeyData>,
335 }
336 
337 impl MlsSize for SecretKeyRatchet {
mls_encoded_len(&self) -> usize338     fn mls_encoded_len(&self) -> usize {
339         let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
340             + self.generation.mls_encoded_len();
341 
342         #[cfg(feature = "out_of_order")]
343         return len + mls_rs_codec::iter::mls_encoded_len(self.history.values());
344         #[cfg(not(feature = "out_of_order"))]
345         return len;
346     }
347 }
348 
#[cfg(feature = "out_of_order")]
impl MlsEncode for SecretKeyRatchet {
    /// Writes the secret as a byte vector, then the generation, then the retained keys.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
        self.generation.mls_encode(writer)?;
        mls_rs_codec::iter::mls_encode(self.history.values(), writer)?;
        Ok(())
    }
}
357 
358 #[cfg(not(feature = "out_of_order"))]
359 impl MlsEncode for SecretKeyRatchet {
mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>360     fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
361         mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
362         self.generation.mls_encode(writer)
363     }
364 }
365 
366 impl MlsDecode for SecretKeyRatchet {
mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>367     fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
368         Ok(Self {
369             secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
370             generation: u32::mls_decode(reader)?,
371             #[cfg(feature = "out_of_order")]
372             history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
373                 let mut items = LargeMap::default();
374 
375                 while !data.is_empty() {
376                     let item = MessageKeyData::mls_decode(data)?;
377                     items.insert(item.generation, item);
378                 }
379 
380                 Ok(items)
381             })?,
382         })
383     }
384 }
385 
386 impl SecretKeyRatchet {
387     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
new<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], key_type: KeyType, ) -> Result<Self, MlsError>388     async fn new<P: CipherSuiteProvider>(
389         cipher_suite_provider: &P,
390         secret: &[u8],
391         key_type: KeyType,
392     ) -> Result<Self, MlsError> {
393         let label = match key_type {
394             KeyType::Handshake => b"handshake".as_slice(),
395             KeyType::Application => b"application".as_slice(),
396         };
397 
398         let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
399             .await
400             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
401 
402         Ok(Self {
403             secret: TreeSecret::from(secret),
404             generation: 0,
405             #[cfg(feature = "out_of_order")]
406             history: Default::default(),
407         })
408     }
409 
410     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, ) -> Result<MessageKeyData, MlsError>411     async fn get_message_key<P: CipherSuiteProvider>(
412         &mut self,
413         cipher_suite_provider: &P,
414         generation: u32,
415     ) -> Result<MessageKeyData, MlsError> {
416         #[cfg(feature = "out_of_order")]
417         if generation < self.generation {
418             return self
419                 .history
420                 .remove_entry(&generation)
421                 .map(|(_, mk)| mk)
422                 .ok_or(MlsError::KeyMissing(generation));
423         }
424 
425         #[cfg(not(feature = "out_of_order"))]
426         if generation < self.generation {
427             return Err(MlsError::KeyMissing(generation));
428         }
429 
430         let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY;
431 
432         if generation > max_generation_allowed {
433             return Err(MlsError::InvalidFutureGeneration(generation));
434         }
435 
436         #[cfg(not(feature = "out_of_order"))]
437         while self.generation < generation {
438             self.next_message_key(cipher_suite_provider)?;
439         }
440 
441         #[cfg(feature = "out_of_order")]
442         while self.generation < generation {
443             let key_data = self.next_message_key(cipher_suite_provider).await?;
444             self.history.insert(key_data.generation, key_data);
445         }
446 
447         self.next_message_key(cipher_suite_provider).await
448     }
449 
450     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, ) -> Result<MessageKeyData, MlsError>451     async fn next_message_key<P: CipherSuiteProvider>(
452         &mut self,
453         cipher_suite_provider: &P,
454     ) -> Result<MessageKeyData, MlsError> {
455         let generation = self.generation;
456 
457         let key = MessageKeyData {
458             nonce: self
459                 .derive_secret(
460                     cipher_suite_provider,
461                     b"nonce",
462                     cipher_suite_provider.aead_nonce_size(),
463                 )
464                 .await?,
465             key: self
466                 .derive_secret(
467                     cipher_suite_provider,
468                     b"key",
469                     cipher_suite_provider.aead_key_size(),
470                 )
471                 .await?,
472             generation,
473         };
474 
475         self.secret = self
476             .derive_secret(
477                 cipher_suite_provider,
478                 b"secret",
479                 cipher_suite_provider.kdf_extract_size(),
480             )
481             .await?
482             .into();
483 
484         self.generation = generation + 1;
485 
486         Ok(key)
487     }
488 
489     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive_secret<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, label: &[u8], len: usize, ) -> Result<Zeroizing<Vec<u8>>, MlsError>490     async fn derive_secret<P: CipherSuiteProvider>(
491         &self,
492         cipher_suite_provider: &P,
493         label: &[u8],
494         len: usize,
495     ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
496         kdf_expand_with_label(
497             cipher_suite_provider,
498             self.secret.as_ref(),
499             label,
500             &self.generation.to_be_bytes(),
501             Some(len),
502         )
503         .await
504         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
505     }
506 }
507 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};

    use super::{KeyType, SecretKeyRatchet, SecretTree};

    /// Builds a secret tree for `leaf_count` leaves whose root holds `secret`.
    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
        let encryption_secret = Zeroizing::new(secret);
        SecretTree::new(leaf_count, encryption_secret)
    }

    impl SecretTree<u32> {
        /// Copies out the root secret without modifying the tree.
        ///
        /// Panics if the root is no longer stored as a secret.
        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
            let root = self.leaf_count.root();
            let mut snapshot = self.known_secrets.clone();
            let root_node = snapshot.take_node(&root).unwrap();

            root_node.into_secret().unwrap().to_vec()
        }
    }

    /// One ratchet-derivation vector: `derive_secret(label, length)` applied to `secret`
    /// at `generation` must yield `out`.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct RatchetInteropTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        generation: u32,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// A basic-crypto interop vector for a single cipher suite.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        derive_tree_secret: RatchetInteropTestCase,
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for test_case in &test_cases {
            // Vectors for cipher suites the test provider lacks are skipped.
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            test_case.derive_tree_secret.verify(&cs).await;
        }
    }

    impl RatchetInteropTestCase {
        /// Asserts that the ratchet derivation reproduces `out` under `cs`.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
                .await
                .unwrap();

            // Overwrite the derived state with the vector's raw inputs.
            ratchet.secret = self.secret.clone().into();
            ratchet.generation = self.generation;

            let computed = ratchet
                .derive_secret(cs, self.label.as_bytes(), self.length)
                .await
                .unwrap();

            assert_eq!(computed.as_slice(), self.out.as_slice());
        }
    }
}
582 
583 #[cfg(test)]
584 mod tests {
585     use alloc::vec;
586 
587     use crate::{
588         cipher_suite::CipherSuite,
589         client::test_utils::TEST_CIPHER_SUITE,
590         crypto::test_utils::{
591             test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
592         },
593         tree_kem::node::NodeIndex,
594     };
595 
596     #[cfg(not(mls_build_async))]
597     use crate::group::test_utils::random_bytes;
598 
599     use super::{test_utils::get_test_tree, *};
600 
601     use assert_matches::assert_matches;
602 
603     #[cfg(target_arch = "wasm32")]
604     use wasm_bindgen_test::wasm_bindgen_test as test;
605 
606     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_tree()607     async fn test_secret_tree() {
608         test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
609         test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
610     }
611 
612     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_secret_tree_custom<T: TreeIndex>( leaf_count: T, leaves_to_check: Vec<T>, all_deleted: bool, )613     async fn test_secret_tree_custom<T: TreeIndex>(
614         leaf_count: T,
615         leaves_to_check: Vec<T>,
616         all_deleted: bool,
617     ) {
618         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
619             let cs_provider = test_cipher_suite_provider(cipher_suite);
620 
621             let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
622             let mut test_tree = get_test_tree(test_secret, leaf_count.clone());
623 
624             let mut secrets = Vec::<SecretRatchets>::new();
625 
626             for i in &leaves_to_check {
627                 let secret = test_tree
628                     .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
629                     .await
630                     .unwrap();
631 
632                 secrets.push(secret);
633             }
634 
635             // Verify the tree is now completely empty
636             assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());
637 
638             // Verify that all the secrets are unique
639             let count = secrets.len();
640             secrets.dedup();
641             assert_eq!(count, secrets.len());
642         }
643     }
644 
645     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_key_ratchet()646     async fn test_secret_key_ratchet() {
647         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
648             let provider = test_cipher_suite_provider(cipher_suite);
649 
650             let mut app_ratchet = SecretKeyRatchet::new(
651                 &provider,
652                 &vec![0u8; provider.kdf_extract_size()],
653                 KeyType::Application,
654             )
655             .await
656             .unwrap();
657 
658             let mut handshake_ratchet = SecretKeyRatchet::new(
659                 &provider,
660                 &vec![0u8; provider.kdf_extract_size()],
661                 KeyType::Handshake,
662             )
663             .await
664             .unwrap();
665 
666             let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
667             let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
668             let app_keys = vec![app_key_one, app_key_two];
669 
670             let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
671             let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
672             let handshake_keys = vec![handshake_key_one, handshake_key_two];
673 
674             // Verify that the keys have different outcomes due to their different labels
675             assert_ne!(app_keys, handshake_keys);
676 
677             // Verify that the keys at each generation are different
678             assert_ne!(handshake_keys[0], handshake_keys[1]);
679         }
680     }
681 
682     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_get_key()683     async fn test_get_key() {
684         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
685             let provider = test_cipher_suite_provider(cipher_suite);
686 
687             let mut ratchet = SecretKeyRatchet::new(
688                 &test_cipher_suite_provider(cipher_suite),
689                 &vec![0u8; provider.kdf_extract_size()],
690                 KeyType::Application,
691             )
692             .await
693             .unwrap();
694 
695             let mut ratchet_clone = ratchet.clone();
696 
697             // This will generate keys 0 and 1 in ratchet_clone
698             let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
699             let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();
700 
701             // Going back in time should result in an error
702             let res = ratchet_clone.get_message_key(&provider, 0).await;
703             assert!(res.is_err());
704 
705             // Calling get key should be the same as calling next until hitting the desired generation
706             let second_key = ratchet
707                 .get_message_key(&provider, ratchet_clone.generation - 1)
708                 .await
709                 .unwrap();
710 
711             assert_eq!(clone_2, second_key)
712         }
713     }
714 
715     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_ratchet()716     async fn test_secret_ratchet() {
717         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
718             let provider = test_cipher_suite_provider(cipher_suite);
719 
720             let mut ratchet = SecretKeyRatchet::new(
721                 &provider,
722                 &vec![0u8; provider.kdf_extract_size()],
723                 KeyType::Application,
724             )
725             .await
726             .unwrap();
727 
728             let original_secret = ratchet.secret.clone();
729             let _ = ratchet.next_message_key(&provider).await.unwrap();
730             let new_secret = ratchet.secret;
731             assert_ne!(original_secret, new_secret)
732         }
733     }
734 
735     #[cfg(feature = "out_of_order")]
736     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_out_of_order_keys()737     async fn test_out_of_order_keys() {
738         let cipher_suite = TEST_CIPHER_SUITE;
739         let provider = test_cipher_suite_provider(cipher_suite);
740 
741         let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
742             .await
743             .unwrap();
744         let mut ratchet_clone = ratchet.clone();
745 
746         // Ask for all the keys in order from the original ratchet
747         let mut ordered_keys = Vec::<MessageKeyData>::new();
748 
749         for i in 0..=MAX_RATCHET_BACK_HISTORY {
750             ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
751         }
752 
753         // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
754         let last_key = ratchet_clone
755             .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
756             .await
757             .unwrap();
758 
759         assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);
760 
761         // Get all the other keys
762         let mut back_history_keys = Vec::<MessageKeyData>::new();
763 
764         for i in 0..MAX_RATCHET_BACK_HISTORY - 1 {
765             back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
766         }
767 
768         assert_eq!(
769             back_history_keys,
770             ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1]
771         );
772     }
773 
774     #[cfg(not(feature = "out_of_order"))]
775     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
out_of_order_keys_should_throw_error()776     async fn out_of_order_keys_should_throw_error() {
777         let cipher_suite = TEST_CIPHER_SUITE;
778         let provider = test_cipher_suite_provider(cipher_suite);
779 
780         let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
781             .await
782             .unwrap();
783 
784         ratchet.get_message_key(&provider, 10).await.unwrap();
785         let res = ratchet.get_message_key(&provider, 9).await;
786         assert_matches!(res, Err(MlsError::KeyMissing(9)))
787     }
788 
789     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_too_out_of_order()790     async fn test_too_out_of_order() {
791         let cipher_suite = TEST_CIPHER_SUITE;
792         let provider = test_cipher_suite_provider(cipher_suite);
793 
794         let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
795             .await
796             .unwrap();
797 
798         let res = ratchet
799             .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
800             .await;
801 
802         let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;
803 
804         assert_matches!(
805             res,
806             Err(MlsError::InvalidFutureGeneration(invalid))
807             if invalid == invalid_generation
808         )
809     }
810 
    /// Expected keys for a single leaf in the `secret_tree` test vector, as
    /// produced by `get_ratchet_data`. Every entry is an MLS-encoded message
    /// key (the output of `mls_encode_to_vec`).
    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Ratchet {
        /// Keys recorded under the "application" label.
        /// NOTE(review): `get_ratchet_data` currently fills this from the
        /// handshake ratchet, not the application one.
        application_keys: Vec<Vec<u8>>,
        /// Keys recorded under the "handshake" label.
        handshake_keys: Vec<Vec<u8>>,
    }
816 
    /// One entry of the `secret_tree` test vector: a cipher suite, the
    /// encryption secret the tree is built from, and the expected keys for
    /// each leaf.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        /// Numeric cipher suite identifier.
        cipher_suite: u16,
        /// Secret passed to `SecretTree::new`; hex-encoded in the JSON file.
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        /// Expected keys, one `Ratchet` per leaf, in leaf order.
        ratchets: Vec<Ratchet>,
    }
824 
    /// Derives a fixed set of message keys from `secret_tree` so they can be
    /// compared with (or recorded into) the `secret_tree` test vector.
    ///
    /// For each `index` in `0..16`, the ratchets at node index `index * 2` are
    /// obtained via `take_leaf_ratchet`. Twenty keys are then drawn for
    /// `application_keys` and twenty more for `handshake_keys`. Every key is
    /// stored MLS-encoded (`mls_encode_to_vec`).
    ///
    /// `maybe_async::must_be_sync` turns this into a plain fn in non-async
    /// builds. That is why the sync `generate_test_vector` calls it without
    /// `.await`.
    ///
    /// NOTE(review): both loops below draw from `ratchets.handshake`. The
    /// first loop presumably meant `ratchets.application`. As written,
    /// `application_keys` holds the first 20 handshake keys and
    /// `handshake_keys` holds the next 20. Changing this alters every recorded
    /// key, so the stored `secret_tree` vector (presumably loaded from disk by
    /// `load_test_case_json!`) would have to be regenerated in the same change.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_ratchet_data(
        secret_tree: &mut SecretTree<NodeIndex>,
        cipher_suite: CipherSuite,
    ) -> Vec<Ratchet> {
        let provider = test_cipher_suite_provider(cipher_suite);
        let mut ratchet_data = Vec::new();

        for index in 0..16 {
            let mut ratchets = secret_tree
                .take_leaf_ratchet(&provider, &(index * 2))
                .await
                .unwrap();

            let mut application_keys = Vec::new();

            for _ in 0..20 {
                // NOTE(review): draws from the handshake ratchet despite
                // filling `application_keys`; see the function docs.
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                application_keys.push(key);
            }

            let mut handshake_keys = Vec::new();

            for _ in 0..20 {
                let key = ratchets
                    .handshake
                    .next_message_key(&provider)
                    .await
                    .unwrap()
                    .mls_encode_to_vec()
                    .unwrap();

                handshake_keys.push(key);
            }

            ratchet_data.push(Ratchet {
                application_keys,
                handshake_keys,
            });
        }

        ratchet_data
    }
875 
876     #[cfg(not(mls_build_async))]
877     #[cfg_attr(coverage_nightly, coverage(off))]
generate_test_vector() -> Vec<TestCase>878     fn generate_test_vector() -> Vec<TestCase> {
879         CipherSuite::all()
880             .map(|cipher_suite| {
881                 let provider = test_cipher_suite_provider(cipher_suite);
882                 let encryption_secret = random_bytes(provider.kdf_extract_size());
883 
884                 let mut secret_tree =
885                     SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));
886 
887                 TestCase {
888                     cipher_suite: cipher_suite.into(),
889                     encryption_secret,
890                     ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
891                 }
892             })
893             .collect()
894     }
895 
    /// Async-build stand-in for the test-vector generator.
    ///
    /// Regenerating the vector is only supported in sync builds, so this
    /// always panics. It is presumably only reached when
    /// `load_test_case_json!` finds no stored vector.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }
900 
901     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_tree_test_vectors()902     async fn test_secret_tree_test_vectors() {
903         let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());
904 
905         for case in test_cases {
906             let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
907                 continue;
908             };
909 
910             let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
911             let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;
912 
913             assert_eq!(ratchet_data, case.ratchets);
914         }
915     }
916 }
917 
#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
mod interop_tests {
    //! Checks `SecretTree` key and nonce derivation against the MLS working
    //! group's shared secret-tree interop test vector.

    #[cfg(not(mls_build_async))]
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
    use zeroize::Zeroizing;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
    };

    use super::SecretTree;

    /// For every case with an available cipher suite provider:
    /// - verifies the sender-data section;
    /// - rebuilds the secret tree from `encryption_secret`;
    /// - checks the application and handshake key/nonce for every listed
    ///   `(leaf, generation)` pair.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn interop_test_vector() {
        // The upstream test vector lives at https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
        let test_cases = load_interop_test_cases();

        for case in test_cases {
            // Skip cipher suites that have no provider in this build.
            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            case.sender_data.verify(&cs).await;

            let mut tree = SecretTree::new(
                case.leaves.len() as u32,
                Zeroizing::new(case.encryption_secret),
            );

            for (index, leaves) in case.leaves.iter().enumerate() {
                // Leaves occupy the even node indices: leaf `index` is node `2 * index`.
                let node_index = (index as u32) * 2;

                for leaf in leaves.iter() {
                    let key = tree
                        .message_key_generation(
                            &cs,
                            node_index,
                            KeyType::Application,
                            leaf.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(key.key.to_vec(), leaf.application_key);
                    assert_eq!(key.nonce.to_vec(), leaf.application_nonce);

                    let key = tree
                        .message_key_generation(
                            &cs,
                            node_index,
                            KeyType::Handshake,
                            leaf.generation,
                        )
                        .await
                        .unwrap();

                    assert_eq!(key.key.to_vec(), leaf.handshake_key);
                    assert_eq!(key.nonce.to_vec(), leaf.handshake_nonce);
                }
            }
        }
    }

    /// One case of the interop vector: a cipher suite, the tree's encryption
    /// secret, a sender-data check, and per-leaf expected keys.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropTestCase {
        /// Numeric cipher suite identifier.
        cipher_suite: u16,
        /// Secret passed to `SecretTree::new`; hex-encoded in JSON.
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        /// Sender-data portion of the case, checked via `InteropSenderData::verify`.
        sender_data: InteropSenderData,
        /// Expected values per leaf; the outer index is the leaf index and
        /// also determines the tree size.
        leaves: Vec<Vec<InteropLeaf>>,
    }

    /// Expected application and handshake key/nonce for one leaf at one
    /// generation. All byte fields are hex-encoded in JSON.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropLeaf {
        generation: u32,
        #[serde(with = "hex::serde")]
        application_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        application_nonce: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_nonce: Vec<u8>,
    }

    /// Loads the `secret_tree_interop` vector, falling back to
    /// `generate_test_vector` (presumably only when no stored vector exists).
    fn load_interop_test_cases() -> Vec<InteropTestCase> {
        load_test_case_json!(secret_tree_interop, generate_test_vector())
    }

    /// Produces a fresh interop vector.
    ///
    /// For every available cipher suite and each tree size in `[1, 8, 32]`,
    /// a tree is seeded with `kdf_extract_size()` random bytes. Keys and nonces
    /// are recorded for generations 0 and 15 of every leaf.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        let mut test_cases = vec![];

        for cs in CipherSuite::all() {
            let Some(cs) = try_test_cipher_suite_provider(*cs) else {
                continue;
            };

            // `generation` rather than `gen`: `gen` is a reserved keyword in Rust 2024.
            let generations = [0, 15];
            let tree_sizes = [1, 8, 32];

            for n_leaves in tree_sizes {
                let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();

                let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));

                let leaves = (0..n_leaves)
                    .map(|leaf| {
                        // Leaf `leaf` is node `2 * leaf` in the tree.
                        let index = leaf * 2u32;

                        generations
                            .into_iter()
                            .map(|generation| {
                                let handshake_key = tree
                                    .message_key_generation(
                                        &cs,
                                        index,
                                        KeyType::Handshake,
                                        generation,
                                    )
                                    .unwrap();

                                let app_key = tree
                                    .message_key_generation(
                                        &cs,
                                        index,
                                        KeyType::Application,
                                        generation,
                                    )
                                    .unwrap();

                                InteropLeaf {
                                    generation,
                                    application_key: app_key.key.to_vec(),
                                    application_nonce: app_key.nonce.to_vec(),
                                    handshake_key: handshake_key.key.to_vec(),
                                    handshake_nonce: handshake_key.nonce.to_vec(),
                                }
                            })
                            .collect()
                    })
                    .collect();

                let case = InteropTestCase {
                    cipher_suite: *cs.cipher_suite(),
                    encryption_secret,
                    sender_data: InteropSenderData::new(&cs),
                    leaves,
                };

                test_cases.push(case);
            }
        }

        test_cases
    }

    /// Async-build stand-in: regenerating the interop vector is only
    /// supported in sync builds, so this always panics.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}
1069