1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use super::*;
6 #[cfg(feature = "tree_index")]
7 use core::fmt::{self, Debug};
8
9 #[cfg(all(feature = "tree_index", feature = "custom_proposal"))]
10 use crate::group::proposal::ProposalType;
11
12 #[cfg(feature = "tree_index")]
13 use crate::{
14 identity::CredentialType,
15 map::{LargeMap, LargeMapEntry},
16 };
17
18 #[cfg(feature = "tree_index")]
19 use mls_rs_core::crypto::SignaturePublicKey;
20
21 #[cfg(all(feature = "tree_index", feature = "std"))]
22 use itertools::Itertools;
23
24 #[cfg(all(feature = "tree_index", not(feature = "std")))]
25 use alloc::collections::BTreeSet;
26
27 #[cfg(feature = "tree_index")]
28 use mls_rs_core::crypto::HpkePublicKey;
29
#[cfg(feature = "tree_index")]
/// Identity bytes returned by the identity provider for a leaf; used as the key
/// of the identity lookup table in `TreeIndex`. The bytes are (de)serialized via
/// `mls_rs_codec::byte_vec`.
#[derive(Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
pub struct Identifier(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    Vec<u8>,
);
33
#[cfg(feature = "tree_index")]
impl Debug for Identifier {
    /// Delegates to `mls_rs_core::debug::pretty_bytes`, labelling the output
    /// as `Identifier`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes = mls_rs_core::debug::pretty_bytes(&self.0);
        bytes.named("Identifier").fmt(f)
    }
}
42
#[cfg(feature = "tree_index")]
/// Lookup tables over the leaves of the tree, used to reject duplicate leaf data
/// and to track which credential (and, with `custom_proposal`, proposal) types
/// the indexed leaves advertise and use.
#[derive(Debug, Clone, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct TreeIndex {
    /// Signature public key of each indexed leaf -> that leaf's index.
    credential_signature_key: LargeMap<SignaturePublicKey, LeafIndex>,
    /// HPKE public key of each indexed leaf -> that leaf's index.
    hpke_key: LargeMap<HpkePublicKey, LeafIndex>,
    /// Identity reported by the identity provider -> leaf index.
    identities: LargeMap<Identifier, LeafIndex>,
    /// Per credential type: how many indexed leaves support it and how many use it.
    credential_type_counters: LargeMap<CredentialType, TypeCounter>,
    /// Per proposal type: how many indexed leaves list it in their capabilities.
    #[cfg(feature = "custom_proposal")]
    proposal_type_counter: LargeMap<ProposalType, u32>,
}
53
/// Resolves the identity of `new_leaf` through `id_provider` and records the
/// leaf in `tree_index` at `new_leaf_idx`.
///
/// # Errors
///
/// Returns `MlsError::IdentityProviderError` if the identity provider fails,
/// otherwise whatever `TreeIndex::insert` reports (duplicate leaf data or
/// credential-type incompatibility).
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(super) async fn index_insert<I: IdentityProvider>(
    tree_index: &mut TreeIndex,
    new_leaf: &LeafNode,
    new_leaf_idx: LeafIndex,
    id_provider: &I,
    extensions: &ExtensionList,
) -> Result<(), MlsError> {
    let identity = match id_provider
        .identity(&new_leaf.signing_identity, extensions)
        .await
    {
        Ok(identity) => identity,
        Err(e) => return Err(MlsError::IdentityProviderError(e.into_any_error())),
    };

    tree_index.insert(new_leaf_idx, new_leaf, identity)
}
70
71 #[cfg(not(feature = "tree_index"))]
72 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
index_insert<I: IdentityProvider>( nodes: &NodeVec, new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, extensions: &ExtensionList, ) -> Result<(), MlsError>73 pub(super) async fn index_insert<I: IdentityProvider>(
74 nodes: &NodeVec,
75 new_leaf: &LeafNode,
76 new_leaf_idx: LeafIndex,
77 id_provider: &I,
78 extensions: &ExtensionList,
79 ) -> Result<(), MlsError> {
80 let new_id = id_provider
81 .identity(&new_leaf.signing_identity, extensions)
82 .await
83 .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
84
85 for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) {
86 (new_leaf.public_key != leaf.public_key)
87 .then_some(())
88 .ok_or(MlsError::DuplicateLeafData(*i))?;
89
90 (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key)
91 .then_some(())
92 .ok_or(MlsError::DuplicateLeafData(*i))?;
93
94 let id = id_provider
95 .identity(&leaf.signing_identity, extensions)
96 .await
97 .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
98
99 (new_id != id)
100 .then_some(())
101 .ok_or(MlsError::DuplicateLeafData(*i))?;
102
103 let cred_type = leaf.signing_identity.credential.credential_type();
104
105 new_leaf
106 .capabilities
107 .credentials
108 .contains(&cred_type)
109 .then_some(())
110 .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?;
111
112 let new_cred_type = new_leaf.signing_identity.credential.credential_type();
113
114 leaf.capabilities
115 .credentials
116 .contains(&new_cred_type)
117 .then_some(())
118 .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?;
119 }
120
121 Ok(())
122 }
123
#[cfg(feature = "tree_index")]
impl TreeIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` once at least one identity has been recorded.
    pub fn is_initialized(&self) -> bool {
        !self.identities.is_empty()
    }

    /// Records `leaf_node`, whose resolved identity is `identity`, at `index`.
    ///
    /// All validation happens before any table or counter is modified, so when
    /// this returns an error the index is left exactly as it was.
    ///
    /// # Errors
    ///
    /// * `MlsError::DuplicateLeafData` (carrying the index of the conflicting
    ///   leaf) if the signature key, HPKE public key, or identity is already
    ///   indexed.
    /// * `MlsError::InUseCredentialTypeUnsupportedByNewLeaf` if a credential type
    ///   used by an indexed leaf is missing from `leaf_node`'s capabilities.
    /// * `MlsError::CredentialTypeOfNewLeafIsUnsupported` if not every indexed
    ///   leaf lists `leaf_node`'s credential type in its capabilities.
    fn insert(
        &mut self,
        index: LeafIndex,
        leaf_node: &LeafNode,
        identity: Vec<u8>,
    ) -> Result<(), MlsError> {
        // One signature-key entry is added per successful insert, so the map's
        // size is the number of leaves indexed before this one.
        let old_leaf_count = self.credential_signature_key.len();

        let pub_key = leaf_node.signing_identity.signature_key.clone();
        let credential_entry = self.credential_signature_key.entry(pub_key);

        if let LargeMapEntry::Occupied(entry) = credential_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone());

        if let LargeMapEntry::Occupied(entry) = hpke_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let identity_entry = self.identities.entry(Identifier(identity));

        if let LargeMapEntry::Occupied(entry) = identity_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        // Every credential type currently used by some leaf must be supported by
        // the new leaf.
        let in_use_cred_type_unsupported_by_new_leaf = self
            .credential_type_counters
            .iter()
            .any(|(cred_type, counters)| {
                counters.used > 0 && !leaf_node.capabilities.credentials.contains(cred_type)
            });

        if in_use_cred_type_unsupported_by_new_leaf {
            return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
        }

        let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        // Every already-indexed leaf must support the new leaf's credential type.
        // This lookup is deliberately read-only: going through
        // `entry(..).or_default()` here would insert a zeroed counter even when
        // the insertion is then rejected, leaving the index modified on error.
        let supported_by_existing_leaves = self
            .credential_type_counters
            .get(&new_leaf_cred_type)
            .map_or(0, |counters| counters.supported);

        if supported_by_existing_leaves != old_leaf_count as u32 {
            return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
        }

        // Validation passed: from here on the index is mutated.
        self.credential_type_counters
            .entry(new_leaf_cred_type)
            .or_default()
            .used += 1;

        let credential_type_iter = leaf_node.capabilities.credentials.iter().copied();

        // Count each advertised credential type once even if listed repeatedly.
        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        // Credential type counter updates
        credential_type_iter.for_each(|cred_type| {
            self.credential_type_counters
                .entry(cred_type)
                .or_default()
                .supported += 1;
        });

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied();

            // Count each advertised proposal type once even if listed repeatedly.
            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            // Proposal type counter update
            proposal_type_iter.for_each(|proposal_type| {
                *self.proposal_type_counter.entry(proposal_type).or_default() += 1;
            });
        }

        identity_entry.or_insert(index);
        credential_entry.or_insert(index);
        hpke_entry.or_insert(index);

        Ok(())
    }

    /// Returns the leaf index recorded for `identity`, if any.
    pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
        self.identities.get(&Identifier(identity.to_vec())).copied()
    }

    /// Removes the signature-key, HPKE-key and identity entries for `leaf_node`.
    ///
    /// The credential-type (and proposal-type) counters are decremented only if
    /// `identity` was actually present in the index; otherwise only the key
    /// mappings are dropped.
    pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) {
        let existed = self
            .identities
            .remove(&Identifier(identity.to_vec()))
            .is_some();

        self.credential_signature_key
            .remove(&leaf_node.signing_identity.signature_key);

        self.hpke_key.remove(&leaf_node.public_key);

        if !existed {
            return;
        }

        // Decrement credential type counters
        let leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) {
            counters.used -= 1;
        }

        let credential_type_iter = leaf_node.capabilities.credentials.iter();

        // Mirror `insert`: each advertised type was counted once.
        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        credential_type_iter.for_each(|cred_type| {
            if let Some(counters) = self.credential_type_counters.get_mut(cred_type) {
                counters.supported -= 1;
            }
        });

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            // Decrement proposal type counters
            proposal_type_iter.for_each(|proposal_type| {
                if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) {
                    *supported -= 1;
                }
            })
        }
    }

    /// Number of indexed leaves whose capabilities list `proposal_type`
    /// (zero if no leaf has ever advertised it).
    #[cfg(feature = "custom_proposal")]
    pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 {
        self.proposal_type_counter
            .get(&proposal_type)
            .copied()
            .unwrap_or_default()
    }

    /// Number of indexed leaves (test helper).
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.credential_signature_key.len()
    }
}
294
#[cfg(feature = "tree_index")]
/// Usage statistics for a single credential type across the indexed leaves.
#[derive(Debug, Clone, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
struct TypeCounter {
    /// Number of indexed leaves listing this type in their capabilities.
    supported: u32,
    /// Number of indexed leaves whose own credential is of this type.
    used: u32,
}
301
#[cfg(feature = "tree_index")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity},
    };
    use alloc::format;
    use assert_matches::assert_matches;

    /// A test leaf together with the index it is inserted at.
    #[derive(Clone, Debug)]
    struct TestData {
        pub leaf_node: LeafNode,
        pub index: LeafIndex,
    }

    /// Builds a basic test leaf named `foo<index>` for the given index.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_test_data(index: LeafIndex) -> TestData {
        let name = format!("foo{}", index.0);
        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, &name).await;

        TestData { leaf_node, index }
    }

    /// Produces ten leaves at indices 0..10 and an index containing all of them.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_setup() -> (Vec<TestData>, TreeIndex) {
        let mut test_data = Vec::with_capacity(10);

        for i in 0..10 {
            test_data.push(get_test_data(LeafIndex(i)).await);
        }

        let mut test_index = TreeIndex::new();

        for data in &test_data {
            let identity = get_test_client_identity(&data.leaf_node);

            test_index
                .insert(data.index, &data.leaf_node, identity)
                .unwrap();
        }

        (test_data, test_index)
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert() {
        let (test_data, test_index) = test_setup().await;

        assert_eq!(test_index.credential_signature_key.len(), test_data.len());
        assert_eq!(test_index.hpke_key.len(), test_data.len());

        for (i, data) in test_data.iter().enumerate() {
            let expected = LeafIndex(i as u32);
            let signature_key = &data.leaf_node.signing_identity.signature_key;

            assert_eq!(
                test_index.credential_signature_key.get(signature_key),
                Some(&expected)
            );

            assert_eq!(
                test_index.hpke_key.get(&data.leaf_node.public_key),
                Some(&expected)
            );
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_credential_key() {
        let (test_data, mut test_index) = test_setup().await;
        let snapshot = test_index.clone();
        let existing = &test_data[1];

        let mut duplicate = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        duplicate.signing_identity = existing.leaf_node.signing_identity.clone();

        let identity = get_test_client_identity(&duplicate);
        let res = test_index.insert(existing.index, &duplicate, identity);

        assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
            if index == *existing.index);

        assert_eq!(snapshot, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_hpke_key() {
        let (test_data, mut test_index) = test_setup().await;
        let snapshot = test_index.clone();
        let existing = &test_data[1];

        let mut duplicate = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        duplicate.public_key = existing.leaf_node.public_key.clone();

        let identity = get_test_client_identity(&duplicate);
        let res = test_index.insert(existing.index, &duplicate, identity);

        assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
            if index == *existing.index);

        assert_eq!(snapshot, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove() {
        let (test_data, mut test_index) = test_setup().await;
        let removed = &test_data[1].leaf_node;

        test_index.remove(removed, &get_test_client_identity(removed));

        let expected_len = test_data.len() - 1;
        assert_eq!(test_index.credential_signature_key.len(), expected_len);
        assert_eq!(test_index.hpke_key.len(), expected_len);

        assert!(test_index
            .credential_signature_key
            .get(&removed.signing_identity.signature_key)
            .is_none());

        assert!(test_index.hpke_key.get(&removed.public_key).is_none());
    }

    #[cfg(feature = "custom_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn custom_proposals() {
        let test_proposal_id = ProposalType::new(42);
        let other_proposal_id = ProposalType::new(45);

        let mut first = get_test_data(LeafIndex(0)).await;
        first.leaf_node.capabilities.proposals.push(test_proposal_id);

        let mut second = get_test_data(LeafIndex(1)).await;
        let second_proposals = &mut second.leaf_node.capabilities.proposals;
        second_proposals.push(test_proposal_id);
        second_proposals.push(other_proposal_id);

        let mut test_index = TreeIndex::new();

        test_index
            .insert(first.index, &first.leaf_node, vec![0])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);

        test_index
            .insert(second.index, &second.leaf_node, vec![1])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1);

        test_index.remove(&second.leaf_node, &[1]);

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0);
    }
}
492