1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
9
10 use itertools::Itertools;
11
12 use crate::{
13 crypto::test_utils::try_test_cipher_suite_provider, identity::basic::BasicIdentityProvider,
14 };
15
16 use super::{
17 node::NodeVec, test_utils::TreeWithSigners, tree_validator::TreeValidator, TreeKemPublic,
18 };
19
20 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
21 struct ValidationTestCase {
22 pub cipher_suite: u16,
23
24 #[serde(with = "hex::serde")]
25 pub tree: Vec<u8>,
26 #[serde(with = "hex::serde")]
27 pub group_id: Vec<u8>,
28 pub tree_hashes: Vec<TreeHash>,
29 pub resolutions: Vec<Vec<u32>>,
30 }
31
32 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
33 struct TreeHash(#[serde(with = "hex::serde")] pub Vec<u8>);
34
35 impl From<crate::tree_kem::tree_hash::TreeHash> for TreeHash {
36 #[cfg_attr(coverage_nightly, coverage(off))]
from(value: crate::tree_kem::tree_hash::TreeHash) -> Self37 fn from(value: crate::tree_kem::tree_hash::TreeHash) -> Self {
38 TreeHash(value.to_vec())
39 }
40 }
41
42 impl ValidationTestCase {
43 #[cfg_attr(coverage_nightly, coverage(off))]
new<P: CipherSuiteProvider>(tree: TreeKemPublic, group_id: &[u8], cs: &P) -> Self44 fn new<P: CipherSuiteProvider>(tree: TreeKemPublic, group_id: &[u8], cs: &P) -> Self {
45 let tree_size = tree.total_leaf_count() * 2 - 1;
46
47 assert!(
48 tree.tree_hashes.current.len() == tree_size as usize,
49 "hashes not initialized"
50 );
51
52 let resolutions = (0..tree_size)
53 .map(
54 #[cfg_attr(coverage_nightly, coverage(off))]
55 |i| tree.nodes.get_resolution_index(i).unwrap(),
56 )
57 .collect();
58
59 Self {
60 cipher_suite: cs.cipher_suite().into(),
61 tree: tree.nodes.mls_encode_to_vec().unwrap(),
62 tree_hashes: tree
63 .tree_hashes
64 .current
65 .into_iter()
66 .map(TreeHash::from)
67 .collect(),
68 group_id: group_id.to_vec(),
69 resolutions,
70 }
71 }
72 }
73
/// Runs every `interop_tree_validation` test case whose cipher suite has a test provider.
///
/// For each case the tree is imported from its encoded nodes, its tree hashes and
/// per-node resolutions are compared against the recorded values, and the tree is then
/// validated against a group context carrying the computed tree hash and the recorded
/// group id.
#[cfg(feature = "rfc_compliant")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn validation() {
    use crate::group::test_utils::get_test_group_context;

    #[cfg(mls_build_async)]
    let test_cases: Vec<ValidationTestCase> = load_test_case_json!(
        interop_tree_validation,
        generate_validation_test_vector().await
    );

    #[cfg(not(mls_build_async))]
    let test_cases: Vec<ValidationTestCase> =
        load_test_case_json!(interop_tree_validation, generate_validation_test_vector());

    for test_case in test_cases {
        // Cipher suites without a test provider are skipped rather than failed.
        let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
            continue;
        };

        let nodes = NodeVec::mls_decode(&mut &*test_case.tree).unwrap();

        let mut tree =
            TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
                .await
                .unwrap();

        let tree_hash = tree.tree_hash(&cs).await.unwrap();

        // `zip_eq` panics if the number of hashes differs from the recorded count.
        let hash_pairs = tree
            .tree_hashes
            .current
            .iter()
            .zip_eq(test_case.tree_hashes.iter());

        for (actual, expected) in hash_pairs {
            assert_eq!(**actual, *expected.0);
        }

        for (node_index, expected) in test_case.resolutions.iter().enumerate() {
            assert_eq!(
                &tree.nodes.get_resolution_index(node_index as u32).unwrap(),
                expected
            );
        }

        let mut context = get_test_group_context(1, test_case.cipher_suite.into()).await;
        context.tree_hash = tree_hash;
        context.group_id = test_case.group_id;

        let validator = TreeValidator::new(&cs, &context, &BasicIdentityProvider);
        validator.validate(&mut tree).await.unwrap();
    }
}
129
/// Generates the tree-validation test vector for every cipher suite that has a test
/// provider.
///
/// Per cipher suite, the trees are produced in a fixed order (full trees, trees with
/// internal or trailing blanks, trees with unmerged leaves) and each one is converted
/// into a [`ValidationTestCase`].
#[cfg(feature = "rfc_compliant")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn generate_validation_test_vector() -> Vec<ValidationTestCase> {
    let mut test_cases = vec![];

    for suite in CipherSuite::all() {
        let Some(cs) = try_test_cipher_suite_provider(*suite) else {
            continue;
        };

        let mut trees = vec![];

        // Trees of increasing complexity, starting with full, complete trees.
        for n_leaves in [2, 4, 8, 32] {
            let full = TreeWithSigners::make_full_tree(n_leaves, &cs).await;
            trees.push(full);
        }

        // Internal blanks, no skipping: 8 leaves, members 2 and 3 are removed, a new
        // member is added, and leaf 0 commits.
        let mut internal_blanks = TreeWithSigners::make_full_tree(8, &cs).await;
        internal_blanks.remove_member(2);
        internal_blanks.remove_member(3);
        internal_blanks.add_member("Bob", &cs).await;
        internal_blanks.update_committer_path(0, &cs).await;
        trees.push(internal_blanks);

        // Blanks at the end, no skipping.
        for n_leaves in [3, 5, 7, 33] {
            let full = TreeWithSigners::make_full_tree(n_leaves, &cs).await;
            trees.push(full);
        }

        // Internal blanks, with skipping: 8 leaves, members 1, 2 and 3 are removed and
        // leaf 0 commits.
        let mut skipped_blanks = TreeWithSigners::make_full_tree(8, &cs).await;
        for member in [1, 2, 3] {
            skipped_blanks.remove_member(member);
        }
        skipped_blanks.update_committer_path(0, &cs).await;
        trees.push(skipped_blanks);

        // Blanks at the end, with skipping.
        for n_leaves in [6, 34] {
            let full = TreeWithSigners::make_full_tree(n_leaves, &cs).await;
            trees.push(full);
        }

        // Unmerged leaves, no skipping: 7 leaves, a member is added and leaf 0 commits.
        let mut unmerged = TreeWithSigners::make_full_tree(7, &cs).await;
        unmerged.add_member("Bob", &cs).await;
        unmerged.update_committer_path(0, &cs).await;
        trees.push(unmerged);

        // Unmerged leaves, with skipping: figure 20 in the RFC.
        let mut unmerged_skipped = TreeWithSigners::make_full_tree(7, &cs).await;
        unmerged_skipped.remove_member(5);
        unmerged_skipped.update_committer_path(0, &cs).await;
        unmerged_skipped.update_committer_path(4, &cs).await;
        unmerged_skipped.add_member("Bob", &cs).await;
        // The member is added after the last path update, so the stored hashes are
        // cleared and recomputed here before the tree is captured.
        unmerged_skipped.tree.tree_hashes.current = vec![];
        unmerged_skipped.tree.tree_hash(&cs).await.unwrap();
        trees.push(unmerged_skipped);

        // Turn every generated tree into a test case, preserving generation order.
        for generated in trees {
            let case = ValidationTestCase::new(generated.tree, &generated.group_id, &cs);
            test_cases.push(case);
        }
    }

    test_cases
}
200