// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use super::*; #[cfg(feature = "tree_index")] use core::fmt::{self, Debug}; #[cfg(all(feature = "tree_index", feature = "custom_proposal"))] use crate::group::proposal::ProposalType; #[cfg(feature = "tree_index")] use crate::identity::CredentialType; #[cfg(feature = "tree_index")] use mls_rs_core::crypto::SignaturePublicKey; #[cfg(all(feature = "tree_index", feature = "std"))] use itertools::Itertools; #[cfg(all(feature = "tree_index", not(feature = "std")))] use alloc::collections::{btree_map::Entry, BTreeMap}; #[cfg(all(feature = "tree_index", feature = "std"))] use std::collections::{hash_map::Entry, HashMap}; #[cfg(all(feature = "tree_index", not(feature = "std")))] use alloc::collections::BTreeSet; #[cfg(feature = "tree_index")] use mls_rs_core::crypto::HpkePublicKey; #[cfg(feature = "tree_index")] #[derive(Clone, Default, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Hash, PartialOrd, Ord)] pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec); #[cfg(feature = "tree_index")] impl Debug for Identifier { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { mls_rs_core::debug::pretty_bytes(&self.0) .named("Identifier") .fmt(f) } } #[cfg(all(feature = "tree_index", feature = "std"))] #[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)] pub struct TreeIndex { credential_signature_key: HashMap, hpke_key: HashMap, identities: HashMap, credential_type_counters: HashMap, #[cfg(feature = "custom_proposal")] proposal_type_counter: HashMap, } #[cfg(all(feature = "tree_index", not(feature = "std")))] #[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)] pub struct TreeIndex { credential_signature_key: BTreeMap, hpke_key: BTreeMap, identities: BTreeMap, credential_type_counters: BTreeMap, #[cfg(feature = "custom_proposal")] proposal_type_counter: BTreeMap, } #[cfg(feature = "tree_index")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(super) async fn index_insert( tree_index: &mut TreeIndex, new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, extensions: &ExtensionList, ) -> Result<(), MlsError> { let new_id = id_provider .identity(&new_leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; tree_index.insert(new_leaf_idx, new_leaf, new_id) } #[cfg(not(feature = "tree_index"))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(super) async fn index_insert( nodes: &NodeVec, new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, extensions: &ExtensionList, ) -> Result<(), MlsError> { let new_id = id_provider .identity(&new_leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) { (new_leaf.public_key != leaf.public_key) .then_some(()) .ok_or(MlsError::DuplicateLeafData(*i))?; (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key) .then_some(()) .ok_or(MlsError::DuplicateLeafData(*i))?; let id = id_provider .identity(&leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; (new_id != id) .then_some(()) .ok_or(MlsError::DuplicateLeafData(*i))?; let cred_type = leaf.signing_identity.credential.credential_type(); new_leaf .capabilities .credentials .contains(&cred_type) .then_some(()) .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?; let new_cred_type = new_leaf.signing_identity.credential.credential_type(); leaf.capabilities .credentials .contains(&new_cred_type) .then_some(()) .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?; } Ok(()) } #[cfg(feature = "tree_index")] impl TreeIndex { pub fn new() -> Self { Default::default() } pub fn is_initialized(&self) -> bool { !self.identities.is_empty() } fn insert( &mut self, index: LeafIndex, leaf_node: &LeafNode, identity: Vec, ) -> Result<(), MlsError> { let old_leaf_count = self.credential_signature_key.len(); let pub_key = leaf_node.signing_identity.signature_key.clone(); let credential_entry = self.credential_signature_key.entry(pub_key); if let Entry::Occupied(entry) = credential_entry { return Err(MlsError::DuplicateLeafData(**entry.get())); } let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone()); if let Entry::Occupied(entry) = hpke_entry { return Err(MlsError::DuplicateLeafData(**entry.get())); } let identity_entry = self.identities.entry(Identifier(identity)); if let Entry::Occupied(entry) = identity_entry { return Err(MlsError::DuplicateLeafData(**entry.get())); } let in_use_cred_type_unsupported_by_new_leaf = self .credential_type_counters .iter() .filter_map(|(cred_type, counters)| Some(*cred_type).filter(|_| counters.used > 0)) .find(|cred_type| !leaf_node.capabilities.credentials.contains(cred_type)); if in_use_cred_type_unsupported_by_new_leaf.is_some() { return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf); } let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type(); let cred_type_counters = self .credential_type_counters .entry(new_leaf_cred_type) .or_default(); if cred_type_counters.supported != old_leaf_count as u32 { return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported); } cred_type_counters.used += 1; let credential_type_iter = leaf_node.capabilities.credentials.iter().copied(); #[cfg(feature = "std")] let credential_type_iter = credential_type_iter.unique(); #[cfg(not(feature = "std"))] let credential_type_iter = credential_type_iter.collect::>().into_iter(); // Credential type counter updates credential_type_iter.for_each(|cred_type| { self.credential_type_counters .entry(cred_type) .or_default() .supported += 1; }); #[cfg(feature = "custom_proposal")] { let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied(); #[cfg(feature = "std")] let proposal_type_iter = proposal_type_iter.unique(); #[cfg(not(feature = "std"))] let proposal_type_iter = proposal_type_iter.collect::>().into_iter(); // Proposal type counter update proposal_type_iter.for_each(|proposal_type| { *self.proposal_type_counter.entry(proposal_type).or_default() += 1; }); } identity_entry.or_insert(index); credential_entry.or_insert(index); hpke_entry.or_insert(index); Ok(()) } pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option { self.identities.get(&Identifier(identity.to_vec())).copied() } pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) { let existed = self .identities .remove(&Identifier(identity.to_vec())) .is_some(); self.credential_signature_key .remove(&leaf_node.signing_identity.signature_key); self.hpke_key.remove(&leaf_node.public_key); if !existed { return; } // Decrement credential type counters let leaf_cred_type = leaf_node.signing_identity.credential.credential_type(); if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) { counters.used -= 1; } let credential_type_iter = leaf_node.capabilities.credentials.iter(); #[cfg(feature = "std")] let credential_type_iter = credential_type_iter.unique(); #[cfg(not(feature = "std"))] let credential_type_iter = credential_type_iter.collect::>().into_iter(); credential_type_iter.for_each(|cred_type| { if let Some(counters) = self.credential_type_counters.get_mut(cred_type) { counters.supported -= 1; } }); #[cfg(feature = "custom_proposal")] { let proposal_type_iter = leaf_node.capabilities.proposals.iter(); #[cfg(feature = "std")] let proposal_type_iter = proposal_type_iter.unique(); #[cfg(not(feature = "std"))] let proposal_type_iter = proposal_type_iter.collect::>().into_iter(); // Decrement proposal type counters proposal_type_iter.for_each(|proposal_type| { if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) { *supported -= 1; } }) } } #[cfg(feature = "custom_proposal")] pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 { self.proposal_type_counter .get(&proposal_type) .copied() .unwrap_or_default() } #[cfg(test)] pub fn len(&self) -> usize { self.credential_signature_key.len() } } #[cfg(feature = "tree_index")] #[derive(Clone, Debug, Default, PartialEq, MlsEncode, MlsDecode, MlsSize)] struct TypeCounter { supported: u32, used: u32, } #[cfg(feature = "tree_index")] #[cfg(test)] mod tests { use super::*; use crate::{ client::test_utils::TEST_CIPHER_SUITE, tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity}, }; use alloc::format; use assert_matches::assert_matches; #[derive(Clone, Debug)] struct TestData { pub leaf_node: LeafNode, pub index: LeafIndex, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn get_test_data(index: LeafIndex) -> TestData { let cipher_suite = TEST_CIPHER_SUITE; let leaf_node = get_basic_test_node(cipher_suite, &format!("foo{}", index.0)).await; TestData { leaf_node, index } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn test_setup() -> (Vec, TreeIndex) { let mut test_data = Vec::new(); for i in 0..10 { test_data.push(get_test_data(LeafIndex(i)).await); } let mut test_index = TreeIndex::new(); test_data.clone().into_iter().for_each(|d| { test_index .insert( d.index, &d.leaf_node, get_test_client_identity(&d.leaf_node), ) .unwrap() }); (test_data, test_index) } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_insert() { let (test_data, test_index) = test_setup().await; assert_eq!(test_index.credential_signature_key.len(), test_data.len()); assert_eq!(test_index.hpke_key.len(), test_data.len()); test_data.into_iter().enumerate().for_each(|(i, d)| { let pub_key = d.leaf_node.signing_identity.signature_key; assert_eq!( test_index.credential_signature_key.get(&pub_key), Some(&LeafIndex(i as u32)) ); assert_eq!( test_index.hpke_key.get(&d.leaf_node.public_key), Some(&LeafIndex(i as u32)) ); }) } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_insert_duplicate_credential_key() { let (test_data, mut test_index) = test_setup().await; let before_error = test_index.clone(); let mut new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await; new_key_package.signing_identity = test_data[1].leaf_node.signing_identity.clone(); let res = test_index.insert( test_data[1].index, &new_key_package, get_test_client_identity(&new_key_package), ); assert_matches!(res, Err(MlsError::DuplicateLeafData(index)) if index == *test_data[1].index); assert_eq!(before_error, test_index); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_insert_duplicate_hpke_key() { let cipher_suite = TEST_CIPHER_SUITE; let (test_data, mut test_index) = test_setup().await; let before_error = test_index.clone(); let mut new_leaf_node = get_basic_test_node(cipher_suite, "foo").await; new_leaf_node.public_key = test_data[1].leaf_node.public_key.clone(); let res = test_index.insert( test_data[1].index, &new_leaf_node, get_test_client_identity(&new_leaf_node), ); assert_matches!(res, Err(MlsError::DuplicateLeafData(index)) if index == *test_data[1].index); assert_eq!(before_error, test_index); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_remove() { let (test_data, mut test_index) = test_setup().await; test_index.remove( &test_data[1].leaf_node, &get_test_client_identity(&test_data[1].leaf_node), ); assert_eq!( test_index.credential_signature_key.len(), test_data.len() - 1 ); assert_eq!(test_index.hpke_key.len(), test_data.len() - 1); assert_eq!( test_index .credential_signature_key .get(&test_data[1].leaf_node.signing_identity.signature_key), None ); assert_eq!( test_index.hpke_key.get(&test_data[1].leaf_node.public_key), None ); } #[cfg(feature = "custom_proposal")] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn custom_proposals() { let test_proposal_id = ProposalType::new(42); let other_proposal_id = ProposalType::new(45); let mut test_data_1 = get_test_data(LeafIndex(0)).await; test_data_1 .leaf_node .capabilities .proposals .push(test_proposal_id); let mut test_data_2 = get_test_data(LeafIndex(1)).await; test_data_2 .leaf_node .capabilities .proposals .push(test_proposal_id); test_data_2 .leaf_node .capabilities .proposals .push(other_proposal_id); let mut test_index = TreeIndex::new(); test_index .insert(test_data_1.index, &test_data_1.leaf_node, vec![0]) .unwrap(); assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1); test_index .insert(test_data_2.index, &test_data_2.leaf_node, vec![1]) .unwrap(); assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2); assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1); test_index.remove(&test_data_2.leaf_node, &[1]); assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1); assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0); } }