1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::client::MlsError; 6 use crate::crypto::CipherSuiteProvider; 7 use crate::group::message_signature::{AuthenticatedContentTBS, FramedContentAuthData}; 8 use crate::group::GroupContext; 9 use alloc::vec::Vec; 10 use core::{ 11 fmt::{self, Debug}, 12 ops::Deref, 13 }; 14 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 15 use mls_rs_core::error::IntoAnyError; 16 17 use super::message_signature::AuthenticatedContent; 18 19 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)] 20 struct AuthenticatedContentTBM<'a> { 21 content_tbs: AuthenticatedContentTBS<'a>, 22 auth: &'a FramedContentAuthData, 23 } 24 25 impl<'a> AuthenticatedContentTBM<'a> { from_authenticated_content( auth_content: &'a AuthenticatedContent, group_context: &'a GroupContext, ) -> AuthenticatedContentTBM<'a>26 pub fn from_authenticated_content( 27 auth_content: &'a AuthenticatedContent, 28 group_context: &'a GroupContext, 29 ) -> AuthenticatedContentTBM<'a> { 30 AuthenticatedContentTBM { 31 content_tbs: AuthenticatedContentTBS::from_authenticated_content( 32 auth_content, 33 Some(group_context), 34 group_context.protocol_version, 35 ), 36 auth: &auth_content.auth, 37 } 38 } 39 } 40 41 #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)] 42 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 43 pub struct MembershipTag(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>); 44 45 impl Debug for MembershipTag { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 47 mls_rs_core::debug::pretty_bytes(&self.0) 48 .named("MembershipTag") 49 .fmt(f) 50 } 51 } 52 53 impl Deref for MembershipTag { 54 type Target = Vec<u8>; 55 deref(&self) -> &Self::Target56 fn deref(&self) -> &Self::Target { 57 &self.0 58 } 59 } 60 61 impl From<Vec<u8>> for MembershipTag { from(m: Vec<u8>) -> Self62 fn from(m: Vec<u8>) -> Self { 63 Self(m) 64 } 65 } 66 67 impl MembershipTag { 68 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] create<P: CipherSuiteProvider>( authenticated_content: &AuthenticatedContent, group_context: &GroupContext, membership_key: &[u8], cipher_suite_provider: &P, ) -> Result<Self, MlsError>69 pub(crate) async fn create<P: CipherSuiteProvider>( 70 authenticated_content: &AuthenticatedContent, 71 group_context: &GroupContext, 72 membership_key: &[u8], 73 cipher_suite_provider: &P, 74 ) -> Result<Self, MlsError> { 75 let plaintext_tbm = AuthenticatedContentTBM::from_authenticated_content( 76 authenticated_content, 77 group_context, 78 ); 79 80 let serialized_tbm = plaintext_tbm.mls_encode_to_vec()?; 81 82 let tag = cipher_suite_provider 83 .mac(membership_key, &serialized_tbm) 84 .await 85 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 86 87 Ok(MembershipTag(tag)) 88 } 89 } 90 91 #[cfg(test)] 92 mod tests { 93 use super::*; 94 use crate::crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider}; 95 use crate::group::{ 96 framing::test_utils::get_test_auth_content, test_utils::get_test_group_context, 97 }; 98 99 #[cfg(not(mls_build_async))] 100 use crate::crypto::test_utils::TestCryptoProvider; 101 102 #[cfg(target_arch = "wasm32")] 103 use wasm_bindgen_test::wasm_bindgen_test as test; 104 105 #[derive(Debug, serde::Serialize, serde::Deserialize)] 106 struct TestCase { 107 cipher_suite: u16, 108 #[serde(with = "hex::serde")] 109 tag: Vec<u8>, 110 } 111 112 #[cfg(not(mls_build_async))] 113 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_cases() -> Vec<TestCase>114 fn generate_test_cases() -> Vec<TestCase> { 115 let mut test_cases = Vec::new(); 116 117 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 118 let tag = MembershipTag::create( 119 &get_test_auth_content(), 120 &get_test_group_context(1, cipher_suite), 121 b"membership_key".as_ref(), 122 &test_cipher_suite_provider(cipher_suite), 123 ) 124 .unwrap(); 125 126 test_cases.push(TestCase { 127 cipher_suite: cipher_suite.into(), 128 tag: tag.to_vec(), 129 }); 130 } 131 132 test_cases 133 } 134 135 #[cfg(mls_build_async)] generate_test_cases() -> Vec<TestCase>136 fn generate_test_cases() -> Vec<TestCase> { 137 panic!("Tests cannot be generated in async mode"); 138 } 139 load_test_cases() -> Vec<TestCase>140 fn load_test_cases() -> Vec<TestCase> { 141 load_test_case_json!(membership_tag, generate_test_cases()) 142 } 143 144 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_membership_tag()145 async fn test_membership_tag() { 146 for case in load_test_cases() { 147 let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else { 148 continue; 149 }; 150 151 let tag = MembershipTag::create( 152 &get_test_auth_content(), 153 &get_test_group_context(1, cs_provider.cipher_suite()).await, 154 b"membership_key".as_ref(), 155 &test_cipher_suite_provider(cs_provider.cipher_suite()), 156 ) 157 .await 158 .unwrap(); 159 160 assert_eq!(**tag, case.tag); 161 } 162 } 163 } 164