• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
9 
10 use itertools::Itertools;
11 
12 use crate::{
13     crypto::test_utils::try_test_cipher_suite_provider, identity::basic::BasicIdentityProvider,
14 };
15 
16 use super::{
17     node::NodeVec, test_utils::TreeWithSigners, tree_validator::TreeValidator, TreeKemPublic,
18 };
19 
20 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
21 struct ValidationTestCase {
22     pub cipher_suite: u16,
23 
24     #[serde(with = "hex::serde")]
25     pub tree: Vec<u8>,
26     #[serde(with = "hex::serde")]
27     pub group_id: Vec<u8>,
28     pub tree_hashes: Vec<TreeHash>,
29     pub resolutions: Vec<Vec<u32>>,
30 }
31 
32 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
33 struct TreeHash(#[serde(with = "hex::serde")] pub Vec<u8>);
34 
35 impl From<crate::tree_kem::tree_hash::TreeHash> for TreeHash {
36     #[cfg_attr(coverage_nightly, coverage(off))]
from(value: crate::tree_kem::tree_hash::TreeHash) -> Self37     fn from(value: crate::tree_kem::tree_hash::TreeHash) -> Self {
38         TreeHash(value.to_vec())
39     }
40 }
41 
42 impl ValidationTestCase {
43     #[cfg_attr(coverage_nightly, coverage(off))]
new<P: CipherSuiteProvider>(tree: TreeKemPublic, group_id: &[u8], cs: &P) -> Self44     fn new<P: CipherSuiteProvider>(tree: TreeKemPublic, group_id: &[u8], cs: &P) -> Self {
45         let tree_size = tree.total_leaf_count() * 2 - 1;
46 
47         assert!(
48             tree.tree_hashes.current.len() == tree_size as usize,
49             "hashes not initialized"
50         );
51 
52         let resolutions = (0..tree_size)
53             .map(
54                 #[cfg_attr(coverage_nightly, coverage(off))]
55                 |i| tree.nodes.get_resolution_index(i).unwrap(),
56             )
57             .collect();
58 
59         Self {
60             cipher_suite: cs.cipher_suite().into(),
61             tree: tree.nodes.mls_encode_to_vec().unwrap(),
62             tree_hashes: tree
63                 .tree_hashes
64                 .current
65                 .into_iter()
66                 .map(TreeHash::from)
67                 .collect(),
68             group_id: group_id.to_vec(),
69             resolutions,
70         }
71     }
72 }
73 
// Interop test for tree validation. For every test vector with a supported
// cipher suite it imports the encoded tree, then checks the computed tree
// hashes, the per-node resolutions and finally runs the full `TreeValidator`.
#[cfg(feature = "rfc_compliant")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn validation() {
    use crate::group::test_utils::get_test_group_context;

    // The two variants differ only in the `.await` on the generator call that
    // is handed to `load_test_case_json!`.
    #[cfg(mls_build_async)]
    let test_cases: Vec<ValidationTestCase> = load_test_case_json!(
        interop_tree_validation,
        generate_validation_test_vector().await
    );

    #[cfg(not(mls_build_async))]
    let test_cases: Vec<ValidationTestCase> =
        load_test_case_json!(interop_tree_validation, generate_validation_test_vector());

    for test_case in test_cases.into_iter() {
        // Skip vectors whose cipher suite has no test provider available.
        let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
            continue;
        };

        let mut tree = TreeKemPublic::import_node_data(
            NodeVec::mls_decode(&mut &*test_case.tree).unwrap(),
            &BasicIdentityProvider,
            &Default::default(),
        )
        .await
        .unwrap();

        let tree_hash = tree.tree_hash(&cs).await.unwrap();

        // `zip_eq` panics on a length mismatch, so a missing or extra hash
        // fails the test instead of being silently ignored.
        tree.tree_hashes
            .current
            .iter()
            .zip_eq(test_case.tree_hashes.iter())
            .for_each(|(l, r)| assert_eq!(**l, *r.0));

        // Entry `i` of the vector is the expected resolution of node index `i`.
        test_case
            .resolutions
            .iter()
            .enumerate()
            .for_each(|(i, res)| {
                assert_eq!(&tree.nodes.get_resolution_index(i as u32).unwrap(), res)
            });

        // Validate against a context that carries this vector's tree hash and group id.
        let mut context = get_test_group_context(1, test_case.cipher_suite.into()).await;
        context.tree_hash = tree_hash;
        context.group_id = test_case.group_id;

        TreeValidator::new(&cs, &context, &BasicIdentityProvider)
            .validate(&mut tree)
            .await
            .unwrap();
    }
}
129 
// Generates the tree-validation test vectors: for every cipher suite that has
// a test provider, builds a fixed series of trees (full, with blanks, with
// unmerged leaves) and converts each one into a `ValidationTestCase`.
#[cfg(feature = "rfc_compliant")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn generate_validation_test_vector() -> Vec<ValidationTestCase> {
    let mut test_cases = vec![];

    for cs in CipherSuite::all() {
        // Cipher suites without a test provider contribute no vectors.
        let Some(cs) = try_test_cipher_suite_provider(*cs) else {
            continue;
        };

        let mut trees = vec![];

        // Generate trees with increasing complexity. Start: full complete trees
        for n_leaves in [2, 4, 8, 32] {
            trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
        }

        // Internal blanks, no skipping: 8 leaves; members 2 and 3 are removed, a new
        // member is added, then member 0 commits
        let mut tree = TreeWithSigners::make_full_tree(8, &cs).await;
        tree.remove_member(2);
        tree.remove_member(3);
        tree.add_member("Bob", &cs).await;
        tree.update_committer_path(0, &cs).await;
        trees.push(tree);

        // Blanks at the end, no skipping
        for n_leaves in [3, 5, 7, 33] {
            trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
        }

        // Internal blanks, with skipping: 8 leaves; members 1, 2 and 3 are removed,
        // then member 0 commits
        let mut tree = TreeWithSigners::make_full_tree(8, &cs).await;
        [1, 2, 3].into_iter().for_each(
            #[cfg_attr(coverage_nightly, coverage(off))]
            |i| tree.remove_member(i),
        );
        tree.update_committer_path(0, &cs).await;
        trees.push(tree);

        // Blanks at the end, with skipping
        for n_leaves in [6, 34] {
            trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
        }

        // Unmerged leaves, no skipping: 7 leaves; a member is added, then member 0 commits
        let mut tree = TreeWithSigners::make_full_tree(7, &cs).await;
        tree.add_member("Bob", &cs).await;
        tree.update_committer_path(0, &cs).await;
        trees.push(tree);

        // Unmerged leaves, with skipping : figure 20 in the RFC
        let mut tree = TreeWithSigners::make_full_tree(7, &cs).await;
        tree.remove_member(5);
        tree.update_committer_path(0, &cs).await;
        tree.update_committer_path(4, &cs).await;
        tree.add_member("Bob", &cs).await;
        // Drop the stored hashes and recompute them for the final tree state, so that
        // `ValidationTestCase::new` finds one hash per node.
        tree.tree.tree_hashes.current = vec![];
        tree.tree.tree_hash(&cs).await.unwrap();
        trees.push(tree);

        // Generate tests
        trees.into_iter().for_each(
            #[cfg_attr(coverage_nightly, coverage(off))]
            |tree| test_cases.push(ValidationTestCase::new(tree.tree, &tree.group_id, &cs)),
        );
    }

    test_cases
}
200