• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use super::*;
6 #[cfg(feature = "tree_index")]
7 use core::fmt::{self, Debug};
8 
9 #[cfg(all(feature = "tree_index", feature = "custom_proposal"))]
10 use crate::group::proposal::ProposalType;
11 
12 #[cfg(feature = "tree_index")]
13 use crate::{
14     identity::CredentialType,
15     map::{LargeMap, LargeMapEntry},
16 };
17 
18 #[cfg(feature = "tree_index")]
19 use mls_rs_core::crypto::SignaturePublicKey;
20 
21 #[cfg(all(feature = "tree_index", feature = "std"))]
22 use itertools::Itertools;
23 
24 #[cfg(all(feature = "tree_index", not(feature = "std")))]
25 use alloc::collections::BTreeSet;
26 
27 #[cfg(feature = "tree_index")]
28 use mls_rs_core::crypto::HpkePublicKey;
29 
/// Opaque identity bytes of a group member, as returned by the group's `IdentityProvider`.
///
/// Used as the key of the identity map in [`TreeIndex`].
#[cfg(feature = "tree_index")]
#[derive(Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[derive(MlsSize, MlsEncode, MlsDecode)]
pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
33 
#[cfg(feature = "tree_index")]
impl Debug for Identifier {
    /// Formats the identity bytes with `mls_rs_core::debug::pretty_bytes`,
    /// labelled with the name `Identifier`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pretty = mls_rs_core::debug::pretty_bytes(&self.0).named("Identifier");
        pretty.fmt(f)
    }
}
42 
/// Lookup tables over the indexed (non-empty) leaves of the tree.
///
/// Used to reject duplicate leaf data and to track which credential types, and
/// optionally custom proposal types, the indexed leaves use and support.
#[cfg(feature = "tree_index")]
#[derive(Clone, Debug, Default, PartialEq)]
#[derive(MlsSize, MlsEncode, MlsDecode)]
pub struct TreeIndex {
    /// Leaf that owns each credential signature public key.
    credential_signature_key: LargeMap<SignaturePublicKey, LeafIndex>,
    /// Leaf that owns each leaf HPKE public key.
    hpke_key: LargeMap<HpkePublicKey, LeafIndex>,
    /// Leaf that owns each identity reported by the identity provider.
    identities: LargeMap<Identifier, LeafIndex>,
    /// Per credential type: how many indexed leaves use it and how many support it.
    credential_type_counters: LargeMap<CredentialType, TypeCounter>,
    /// Per custom proposal type: how many indexed leaves list it in their capabilities.
    #[cfg(feature = "custom_proposal")]
    proposal_type_counter: LargeMap<ProposalType, u32>,
}
53 
/// Resolves the identity of `new_leaf` through `id_provider` and records the leaf
/// at `new_leaf_idx` in `tree_index`.
///
/// # Errors
///
/// Returns `MlsError::IdentityProviderError` if the identity cannot be resolved.
/// Otherwise it propagates whatever `TreeIndex::insert` returns, such as duplicate
/// leaf data or credential type incompatibilities.
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(super) async fn index_insert<I: IdentityProvider>(
    tree_index: &mut TreeIndex,
    new_leaf: &LeafNode,
    new_leaf_idx: LeafIndex,
    id_provider: &I,
    extensions: &ExtensionList,
) -> Result<(), MlsError> {
    let resolved = id_provider
        .identity(&new_leaf.signing_identity, extensions)
        .await;

    let new_id = match resolved {
        Ok(id) => id,
        Err(e) => return Err(MlsError::IdentityProviderError(e.into_any_error())),
    };

    tree_index.insert(new_leaf_idx, new_leaf, new_id)
}
70 
71 #[cfg(not(feature = "tree_index"))]
72 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
index_insert<I: IdentityProvider>( nodes: &NodeVec, new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, extensions: &ExtensionList, ) -> Result<(), MlsError>73 pub(super) async fn index_insert<I: IdentityProvider>(
74     nodes: &NodeVec,
75     new_leaf: &LeafNode,
76     new_leaf_idx: LeafIndex,
77     id_provider: &I,
78     extensions: &ExtensionList,
79 ) -> Result<(), MlsError> {
80     let new_id = id_provider
81         .identity(&new_leaf.signing_identity, extensions)
82         .await
83         .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
84 
85     for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) {
86         (new_leaf.public_key != leaf.public_key)
87             .then_some(())
88             .ok_or(MlsError::DuplicateLeafData(*i))?;
89 
90         (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key)
91             .then_some(())
92             .ok_or(MlsError::DuplicateLeafData(*i))?;
93 
94         let id = id_provider
95             .identity(&leaf.signing_identity, extensions)
96             .await
97             .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
98 
99         (new_id != id)
100             .then_some(())
101             .ok_or(MlsError::DuplicateLeafData(*i))?;
102 
103         let cred_type = leaf.signing_identity.credential.credential_type();
104 
105         new_leaf
106             .capabilities
107             .credentials
108             .contains(&cred_type)
109             .then_some(())
110             .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?;
111 
112         let new_cred_type = new_leaf.signing_identity.credential.credential_type();
113 
114         leaf.capabilities
115             .credentials
116             .contains(&new_cred_type)
117             .then_some(())
118             .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?;
119     }
120 
121     Ok(())
122 }
123 
#[cfg(feature = "tree_index")]
impl TreeIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` once at least one identity has been indexed.
    pub fn is_initialized(&self) -> bool {
        !self.identities.is_empty()
    }

    /// Indexes `leaf_node` at `index` under `identity`.
    ///
    /// # Errors
    ///
    /// * [`MlsError::DuplicateLeafData`] if the leaf's signature key, HPKE key or identity
    ///   is already indexed. The error carries the index of the conflicting leaf.
    /// * [`MlsError::InUseCredentialTypeUnsupportedByNewLeaf`] if a credential type used
    ///   by an indexed leaf is missing from the new leaf's capabilities.
    /// * [`MlsError::CredentialTypeOfNewLeafIsUnsupported`] if not every indexed leaf
    ///   supports the new leaf's credential type.
    ///
    /// All checks run before any map is modified, so on error the index is unchanged.
    fn insert(
        &mut self,
        index: LeafIndex,
        leaf_node: &LeafNode,
        identity: Vec<u8>,
    ) -> Result<(), MlsError> {
        // Each indexed leaf owns exactly one signature-key entry, so this is the leaf count.
        let old_leaf_count = self.credential_signature_key.len();

        let pub_key = leaf_node.signing_identity.signature_key.clone();
        let credential_entry = self.credential_signature_key.entry(pub_key);

        if let LargeMapEntry::Occupied(entry) = credential_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone());

        if let LargeMapEntry::Occupied(entry) = hpke_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let identity_entry = self.identities.entry(Identifier(identity));

        if let LargeMapEntry::Occupied(entry) = identity_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        // The new leaf must support every credential type that is currently in use.
        let in_use_cred_type_unsupported_by_new_leaf =
            self.credential_type_counters
                .iter()
                .any(|(cred_type, counters)| {
                    counters.used > 0 && !leaf_node.capabilities.credentials.contains(cred_type)
                });

        if in_use_cred_type_unsupported_by_new_leaf {
            return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
        }

        let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        // Every indexed leaf must support the new leaf's credential type. The counter is
        // read without inserting, so a rejected leaf does not leave a zeroed entry behind.
        let supported_by_existing_leaves = self
            .credential_type_counters
            .get(&new_leaf_cred_type)
            .map_or(0, |counters| counters.supported);

        // `LeafIndex` wraps a `u32`, so the leaf count fits in `u32`.
        if supported_by_existing_leaves != old_leaf_count as u32 {
            return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
        }

        // All checks passed: from here on the index is updated.
        self.credential_type_counters
            .entry(new_leaf_cred_type)
            .or_default()
            .used += 1;

        let credential_type_iter = leaf_node.capabilities.credentials.iter().copied();

        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        // Count each distinct supported credential type once per leaf.
        for cred_type in credential_type_iter {
            self.credential_type_counters
                .entry(cred_type)
                .or_default()
                .supported += 1;
        }

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            // Count each distinct supported proposal type once per leaf.
            for proposal_type in proposal_type_iter {
                *self.proposal_type_counter.entry(proposal_type).or_default() += 1;
            }
        }

        identity_entry.or_insert(index);
        credential_entry.or_insert(index);
        hpke_entry.or_insert(index);

        Ok(())
    }

    /// Returns the leaf indexed under `identity`, if any.
    pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
        self.identities.get(&Identifier(identity.to_vec())).copied()
    }

    /// Removes `leaf_node`, indexed under `identity`, from the index.
    ///
    /// The leaf's signature-key and HPKE-key entries are always dropped. Credential and
    /// proposal counters are decremented only if `identity` was present in the index.
    ///
    /// NOTE(review): if `identity` was not indexed, the key maps shrink but the
    /// `supported` counters do not. A later `insert` compares those counters against
    /// the key-map size and would then fail. Confirm that callers always pass the
    /// identity that was used at insert time.
    pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) {
        let existed = self
            .identities
            .remove(&Identifier(identity.to_vec()))
            .is_some();

        self.credential_signature_key
            .remove(&leaf_node.signing_identity.signature_key);

        self.hpke_key.remove(&leaf_node.public_key);

        if !existed {
            return;
        }

        // Decrement credential type counters.
        let leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) {
            counters.used -= 1;
        }

        let credential_type_iter = leaf_node.capabilities.credentials.iter();

        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        for cred_type in credential_type_iter {
            if let Some(counters) = self.credential_type_counters.get_mut(cred_type) {
                counters.supported -= 1;
            }
        }

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            // Decrement proposal type counters.
            for proposal_type in proposal_type_iter {
                if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) {
                    *supported -= 1;
                }
            }
        }
    }

    /// Returns how many indexed leaves list `proposal_type` in their capabilities.
    #[cfg(feature = "custom_proposal")]
    pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 {
        self.proposal_type_counter
            .get(&proposal_type)
            .copied()
            .unwrap_or_default()
    }

    /// Number of indexed leaves, i.e. the size of the signature-key map.
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.credential_signature_key.len()
    }
}
294 
/// Per-credential-type tally over the indexed leaves.
#[cfg(feature = "tree_index")]
#[derive(Clone, Debug, Default, PartialEq)]
#[derive(MlsEncode, MlsDecode, MlsSize)]
struct TypeCounter {
    /// Number of indexed leaves whose capabilities list this credential type.
    supported: u32,
    /// Number of indexed leaves whose own credential has this type.
    used: u32,
}
301 
#[cfg(feature = "tree_index")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity},
    };
    use alloc::format;
    use assert_matches::assert_matches;

    /// A test leaf node together with the index it is inserted at.
    #[derive(Clone, Debug)]
    struct TestData {
        pub leaf_node: LeafNode,
        pub index: LeafIndex,
    }

    /// Builds the basic test leaf named `foo{index}` for `index`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_test_data(index: LeafIndex) -> TestData {
        let name = format!("foo{}", index.0);
        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, &name).await;

        TestData { leaf_node, index }
    }

    /// Creates ten test leaves at indices 0..10 and an index containing all of them.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_setup() -> (Vec<TestData>, TreeIndex) {
        let mut test_data = Vec::with_capacity(10);

        for i in 0..10 {
            test_data.push(get_test_data(LeafIndex(i)).await);
        }

        let mut test_index = TreeIndex::new();

        for data in &test_data {
            let identity = get_test_client_identity(&data.leaf_node);
            test_index
                .insert(data.index, &data.leaf_node, identity)
                .unwrap();
        }

        (test_data, test_index)
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert() {
        let (test_data, test_index) = test_setup().await;

        assert_eq!(test_index.credential_signature_key.len(), test_data.len());
        assert_eq!(test_index.hpke_key.len(), test_data.len());

        for (position, data) in test_data.iter().enumerate() {
            let expected = LeafIndex(position as u32);
            let signature_key = &data.leaf_node.signing_identity.signature_key;

            assert_eq!(
                test_index.credential_signature_key.get(signature_key),
                Some(&expected)
            );

            assert_eq!(
                test_index.hpke_key.get(&data.leaf_node.public_key),
                Some(&expected)
            );
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_credential_key() {
        let (test_data, mut test_index) = test_setup().await;
        let before_error = test_index.clone();
        let existing = &test_data[1];

        // Reuse an existing member's signing identity on a fresh leaf.
        let mut new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        new_key_package.signing_identity = existing.leaf_node.signing_identity.clone();

        let identity = get_test_client_identity(&new_key_package);
        let res = test_index.insert(existing.index, &new_key_package, identity);

        assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
                        if index == *existing.index);

        assert_eq!(before_error, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_hpke_key() {
        let (test_data, mut test_index) = test_setup().await;
        let before_error = test_index.clone();
        let existing = &test_data[1];

        // Reuse an existing member's HPKE key on a fresh leaf.
        let mut new_leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        new_leaf_node.public_key = existing.leaf_node.public_key.clone();

        let identity = get_test_client_identity(&new_leaf_node);
        let res = test_index.insert(existing.index, &new_leaf_node, identity);

        assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
                        if index == *existing.index);

        assert_eq!(before_error, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove() {
        let (test_data, mut test_index) = test_setup().await;
        let removed = &test_data[1].leaf_node;

        test_index.remove(removed, &get_test_client_identity(removed));

        let remaining = test_data.len() - 1;

        assert_eq!(test_index.credential_signature_key.len(), remaining);
        assert_eq!(test_index.hpke_key.len(), remaining);

        assert_eq!(
            test_index
                .credential_signature_key
                .get(&removed.signing_identity.signature_key),
            None
        );

        assert_eq!(test_index.hpke_key.get(&removed.public_key), None);
    }

    #[cfg(feature = "custom_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn custom_proposals() {
        let test_proposal_id = ProposalType::new(42);
        let other_proposal_id = ProposalType::new(45);

        // First leaf supports only the test proposal.
        let mut first = get_test_data(LeafIndex(0)).await;
        first.leaf_node.capabilities.proposals.push(test_proposal_id);

        // Second leaf supports both the test and the other proposal.
        let mut second = get_test_data(LeafIndex(1)).await;
        second.leaf_node.capabilities.proposals.push(test_proposal_id);
        second.leaf_node.capabilities.proposals.push(other_proposal_id);

        let mut test_index = TreeIndex::new();

        test_index
            .insert(first.index, &first.leaf_node, vec![0])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);

        test_index
            .insert(second.index, &second.leaf_node, vec![1])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1);

        test_index.remove(&second.leaf_node, &[1]);

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0);
    }
}
492