• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use super::leaf_node::LeafNode;
6 use super::node::{LeafIndex, NodeVec};
7 use super::tree_math::BfsIterTopDown;
8 use crate::client::MlsError;
9 use crate::crypto::CipherSuiteProvider;
10 use crate::tree_kem::math as tree_math;
11 use crate::tree_kem::node::Parent;
12 use crate::tree_kem::TreeKemPublic;
13 use alloc::collections::VecDeque;
14 use alloc::vec;
15 use alloc::vec::Vec;
16 use core::fmt::{self, Debug};
17 use itertools::Itertools;
18 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
19 use mls_rs_core::error::IntoAnyError;
20 use tree_math::TreeIndex;
21 
22 use core::ops::Deref;
23 
24 #[derive(Clone, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
25 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26 pub(crate) struct TreeHash(
27     #[mls_codec(with = "mls_rs_codec::byte_vec")]
28     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
29     Vec<u8>,
30 );
31 
32 impl Debug for TreeHash {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result33     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34         mls_rs_core::debug::pretty_bytes(&self.0)
35             .named("TreeHash")
36             .fmt(f)
37     }
38 }
39 
40 impl Deref for TreeHash {
41     type Target = [u8];
42 
deref(&self) -> &Self::Target43     fn deref(&self) -> &Self::Target {
44         &self.0
45     }
46 }
47 
/// Cache of per-node tree hashes.
///
/// `current` is indexed by node index; leaf `i` occupies slot `2 * i`. An empty vector is
/// treated by the hashing code in this module as "not computed yet".
// NOTE(review): the derived codec/serde impls presumably make this field's layout part of any
// persisted state — avoid changing it without a migration.
#[derive(Clone, Debug, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct TreeHashes {
    pub current: Vec<TreeHash>,
}
53 
/// Value that is MLS-encoded and hashed to produce a leaf slot's tree hash.
///
/// `leaf_node` is `None` when the slot is blank or has been deliberately filtered out.
// NOTE(review): with a derived `MlsEncode`, field order presumably defines the encoded byte
// layout — do not reorder these fields.
#[derive(Debug, MlsSize, MlsEncode)]
struct LeafNodeHashInput<'a> {
    leaf_index: LeafIndex,
    leaf_node: Option<&'a LeafNode>,
}
59 
/// Value that is MLS-encoded and hashed to produce a parent node's tree hash.
///
/// `parent_node` is `None` for a blank parent; `left_hash` / `right_hash` are the already
/// computed tree hashes of the two children, each encoded as a variable-length byte vector.
// NOTE(review): with a derived `MlsEncode`, field order presumably defines the encoded byte
// layout — do not reorder these fields.
#[derive(Debug, MlsSize, MlsEncode)]
struct ParentNodeTreeHashInput<'a> {
    parent_node: Option<&'a Parent>,
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    left_hash: &'a [u8],
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    right_hash: &'a [u8],
}
68 
/// Tagged hash input: either a leaf slot or a parent node.
///
/// The explicit `u8` discriminants (1 = leaf, 2 = parent) presumably act as the node-type tag
/// in the encoding, matching the `TreeHashInput` structure of RFC 9420 — TODO confirm against
/// the codec derive; do not renumber.
#[derive(Debug, MlsSize, MlsEncode)]
#[repr(u8)]
enum TreeHashInput<'a> {
    Leaf(LeafNodeHashInput<'a>) = 1u8,
    Parent(ParentNodeTreeHashInput<'a>) = 2u8,
}
75 
76 impl TreeKemPublic {
77     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
78     #[inline(never)]
tree_hash<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, ) -> Result<Vec<u8>, MlsError>79     pub async fn tree_hash<P: CipherSuiteProvider>(
80         &mut self,
81         cipher_suite_provider: &P,
82     ) -> Result<Vec<u8>, MlsError> {
83         self.initialize_hashes(cipher_suite_provider).await?;
84         let root = self.total_leaf_count().root();
85         Ok(self.tree_hashes.current[root as usize].to_vec())
86     }
87 
88     // Update hashes after `committer` makes changes to the tree. `path_blank` is the
89     // list of leaves whose paths were blanked, i.e. updates and removes.
90     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_hashes<P: CipherSuiteProvider>( &mut self, updated_leaves: &[LeafIndex], cipher_suite_provider: &P, ) -> Result<(), MlsError>91     pub async fn update_hashes<P: CipherSuiteProvider>(
92         &mut self,
93         updated_leaves: &[LeafIndex],
94         cipher_suite_provider: &P,
95     ) -> Result<(), MlsError> {
96         let num_leaves = self.total_leaf_count();
97 
98         let trailing_blanks = (0..num_leaves)
99             .rev()
100             .map_while(|l| {
101                 self.tree_hashes
102                     .current
103                     .get(2 * l as usize)
104                     .is_none()
105                     .then_some(LeafIndex(l))
106             })
107             .collect::<Vec<_>>();
108 
109         // Update the current hashes for direct paths of all modified leaves.
110         tree_hash(
111             &mut self.tree_hashes.current,
112             &self.nodes,
113             Some([updated_leaves, &trailing_blanks].concat()),
114             &[],
115             num_leaves,
116             cipher_suite_provider,
117         )
118         .await?;
119 
120         Ok(())
121     }
122 
123     // Initialize all hashes after creating / importing a tree.
124     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
initialize_hashes<P>(&mut self, cipher_suite_provider: &P) -> Result<(), MlsError> where P: CipherSuiteProvider,125     async fn initialize_hashes<P>(&mut self, cipher_suite_provider: &P) -> Result<(), MlsError>
126     where
127         P: CipherSuiteProvider,
128     {
129         if self.tree_hashes.current.is_empty() {
130             let num_leaves = self.total_leaf_count();
131 
132             tree_hash(
133                 &mut self.tree_hashes.current,
134                 &self.nodes,
135                 None,
136                 &[],
137                 num_leaves,
138                 cipher_suite_provider,
139             )
140             .await?;
141         }
142 
143         Ok(())
144     }
145 
unmerged_in_subtree( &self, node_unmerged: u32, subtree_root: u32, ) -> Result<&[LeafIndex], MlsError>146     pub(crate) fn unmerged_in_subtree(
147         &self,
148         node_unmerged: u32,
149         subtree_root: u32,
150     ) -> Result<&[LeafIndex], MlsError> {
151         let unmerged = &self.nodes.borrow_as_parent(node_unmerged)?.unmerged_leaves;
152         let (left, right) = tree_math::subtree(subtree_root);
153         let mut start = 0;
154         while start < unmerged.len() && unmerged[start] < left {
155             start += 1;
156         }
157         let mut end = start;
158         while end < unmerged.len() && unmerged[end] < right {
159             end += 1;
160         }
161         Ok(&unmerged[start..end])
162     }
163 
different_unmerged(&self, ancestor: u32, descendant: u32) -> Result<bool, MlsError>164     fn different_unmerged(&self, ancestor: u32, descendant: u32) -> Result<bool, MlsError> {
165         Ok(!self.nodes.is_blank(ancestor)?
166             && !self.nodes.is_blank(descendant)?
167             && self.unmerged_in_subtree(ancestor, descendant)?
168                 != self.nodes.borrow_as_parent(descendant)?.unmerged_leaves)
169     }
170 
171     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
compute_original_hashes<P: CipherSuiteProvider>( &self, cipher_suite: &P, ) -> Result<Vec<TreeHash>, MlsError>172     pub(crate) async fn compute_original_hashes<P: CipherSuiteProvider>(
173         &self,
174         cipher_suite: &P,
175     ) -> Result<Vec<TreeHash>, MlsError> {
176         let num_leaves = self.nodes.total_leaf_count() as usize;
177         let root = (num_leaves as u32).root();
178 
179         // The value `filtered_sets[n]` is a list of all ancestors `a` of `n` s.t. we have to compute
180         // the tree hash of `n` with the unmerged leaves of `a` filtered out.
181         let mut filtered_sets = vec![vec![]; num_leaves * 2 - 1];
182         filtered_sets[root as usize].push(root);
183         let mut tree_hashes = vec![vec![]; num_leaves * 2 - 1];
184 
185         let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1);
186 
187         for n in bfs_iter {
188             let Some(ps) = (n as u32).parent_sibling(&(num_leaves as u32)) else {
189                 break;
190             };
191 
192             let p = ps.parent;
193 
194             // Clippy's suggestion `filtered_sets[n].clone_from(&filtered_sets[p as usize])` is wrong and does not compile
195             #[allow(clippy::assigning_clones)]
196             {
197                 filtered_sets[n] = filtered_sets[p as usize].clone();
198             }
199 
200             if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? {
201                 filtered_sets[n].push(p);
202 
203                 // Compute tree hash of `n` without unmerged leaves of `p`. This also computes the tree hash
204                 // for any descendants of `n` added to `filtered_sets` later via `clone`.
205                 let (start_leaf, end_leaf) = tree_math::subtree(n as u32);
206 
207                 tree_hash(
208                     &mut tree_hashes[p as usize],
209                     &self.nodes,
210                     Some((*start_leaf..*end_leaf).map(LeafIndex).collect_vec()),
211                     &self.nodes.borrow_as_parent(p)?.unmerged_leaves,
212                     num_leaves as u32,
213                     cipher_suite,
214                 )
215                 .await?;
216             }
217         }
218 
219         // Set the `original_hashes` based on the computed `hashes`.
220         let mut original_hashes = vec![TreeHash::default(); num_leaves * 2 - 1];
221 
222         // If root has unmerged leaves, we recompute it's original hash. Else, we can use the current hash.
223         let root_original = if !self.nodes.is_blank(root)? && !self.nodes.is_leaf(root) {
224             let root_unmerged = &self.nodes.borrow_as_parent(root)?.unmerged_leaves;
225 
226             if !root_unmerged.is_empty() {
227                 let mut hashes = vec![];
228 
229                 tree_hash(
230                     &mut hashes,
231                     &self.nodes,
232                     None,
233                     root_unmerged,
234                     num_leaves as u32,
235                     cipher_suite,
236                 )
237                 .await?;
238 
239                 Some(hashes)
240             } else {
241                 None
242             }
243         } else {
244             None
245         };
246 
247         for (i, hash) in original_hashes.iter_mut().enumerate() {
248             let a = filtered_sets[i].last().unwrap();
249             *hash = if self.nodes.is_blank(*a)? || a == &root {
250                 if let Some(root_original) = &root_original {
251                     root_original[i].clone()
252                 } else {
253                     self.tree_hashes.current[i].clone()
254                 }
255             } else {
256                 tree_hashes[*a as usize][i].clone()
257             }
258         }
259 
260         Ok(original_hashes)
261     }
262 }
263 
264 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
tree_hash<P: CipherSuiteProvider>( hashes: &mut Vec<TreeHash>, nodes: &NodeVec, leaves_to_update: Option<Vec<LeafIndex>>, filtered_leaves: &[LeafIndex], num_leaves: u32, cipher_suite_provider: &P, ) -> Result<(), MlsError>265 async fn tree_hash<P: CipherSuiteProvider>(
266     hashes: &mut Vec<TreeHash>,
267     nodes: &NodeVec,
268     leaves_to_update: Option<Vec<LeafIndex>>,
269     filtered_leaves: &[LeafIndex],
270     num_leaves: u32,
271     cipher_suite_provider: &P,
272 ) -> Result<(), MlsError> {
273     let leaves_to_update =
274         leaves_to_update.unwrap_or_else(|| (0..num_leaves).map(LeafIndex).collect::<Vec<_>>());
275 
276     // Resize the array in case the tree was extended or truncated
277     hashes.resize(num_leaves as usize * 2 - 1, TreeHash::default());
278 
279     let mut node_queue = VecDeque::with_capacity(leaves_to_update.len());
280 
281     for l in leaves_to_update.iter().filter(|l| ***l < num_leaves) {
282         let leaf = (!filtered_leaves.contains(l))
283             .then_some(nodes.borrow_as_leaf(*l).ok())
284             .flatten();
285 
286         hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?);
287 
288         if let Some(ps) = (2 * **l).parent_sibling(&num_leaves) {
289             node_queue.push_back(ps.parent);
290         }
291     }
292 
293     while let Some(n) = node_queue.pop_front() {
294         let hash = TreeHash(
295             hash_for_parent(
296                 nodes.borrow_as_parent(n).ok(),
297                 cipher_suite_provider,
298                 filtered_leaves,
299                 &hashes[n.left_unchecked() as usize],
300                 &hashes[n.right_unchecked() as usize],
301             )
302             .await?,
303         );
304 
305         hashes[n as usize] = hash;
306 
307         if let Some(ps) = n.parent_sibling(&num_leaves) {
308             node_queue.push_back(ps.parent);
309         }
310     }
311 
312     Ok(())
313 }
314 
315 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
hash_for_leaf<P: CipherSuiteProvider>( leaf_index: LeafIndex, leaf_node: Option<&LeafNode>, cipher_suite_provider: &P, ) -> Result<Vec<u8>, MlsError>316 async fn hash_for_leaf<P: CipherSuiteProvider>(
317     leaf_index: LeafIndex,
318     leaf_node: Option<&LeafNode>,
319     cipher_suite_provider: &P,
320 ) -> Result<Vec<u8>, MlsError> {
321     let input = TreeHashInput::Leaf(LeafNodeHashInput {
322         leaf_index,
323         leaf_node,
324     });
325 
326     cipher_suite_provider
327         .hash(&input.mls_encode_to_vec()?)
328         .await
329         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
330 }
331 
332 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
hash_for_parent<P: CipherSuiteProvider>( parent_node: Option<&Parent>, cipher_suite_provider: &P, filtered: &[LeafIndex], left_hash: &[u8], right_hash: &[u8], ) -> Result<Vec<u8>, MlsError>333 async fn hash_for_parent<P: CipherSuiteProvider>(
334     parent_node: Option<&Parent>,
335     cipher_suite_provider: &P,
336     filtered: &[LeafIndex],
337     left_hash: &[u8],
338     right_hash: &[u8],
339 ) -> Result<Vec<u8>, MlsError> {
340     let mut parent_node = parent_node.cloned();
341 
342     if let Some(ref mut parent_node) = parent_node {
343         parent_node
344             .unmerged_leaves
345             .retain(|unmerged_index| !filtered.contains(unmerged_index));
346     }
347 
348     let input = TreeHashInput::Parent(ParentNodeTreeHashInput {
349         parent_node: parent_node.as_ref(),
350         left_hash,
351         right_hash,
352     });
353 
354     cipher_suite_provider
355         .hash(&input.mls_encode_to_vec()?)
356         .await
357         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
358 }
359 
#[cfg(test)]
mod tests {
    use mls_rs_codec::MlsDecode;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
        identity::basic::BasicIdentityProvider,
        tree_kem::{node::NodeVec, parent_hash::test_utils::get_test_tree_fig_12},
    };

    use super::*;

    /// Known-answer vector: an encoded tree for one cipher suite plus its expected tree hash.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        tree_data: Vec<u8>,
        #[serde(with = "hex::serde")]
        tree_hash: Vec<u8>,
    }

    impl TestCase {
        /// Produces one vector per cipher suite from the figure-12 test tree.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let mut vectors = Vec::new();

            for suite in CipherSuite::all() {
                let mut tree = get_test_tree_fig_12(suite).await;

                // Encode the nodes before hashing, mirroring the field order of the vector.
                let tree_data = tree.nodes.mls_encode_to_vec().unwrap();
                let provider = test_cipher_suite_provider(suite);
                let tree_hash = tree.tree_hash(&provider).await.unwrap();

                vectors.push(TestCase {
                    cipher_suite: suite.into(),
                    tree_data,
                    tree_hash,
                });
            }

            vectors
        }
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_hash, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_hash, TestCase::generate())
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_tree_hash() {
        for vector in load_test_cases().await {
            // Skip suites the test crypto provider does not support.
            let cs_provider = match try_test_cipher_suite_provider(vector.cipher_suite) {
                Some(provider) => provider,
                None => continue,
            };

            let decoded_nodes = NodeVec::mls_decode(&mut &*vector.tree_data).unwrap();

            let mut tree = TreeKemPublic::import_node_data(
                decoded_nodes,
                &BasicIdentityProvider,
                &Default::default(),
            )
            .await
            .unwrap();

            let computed = tree.tree_hash(&cs_provider).await.unwrap();

            assert_eq!(computed, vector.tree_hash);
        }
    }
}
438