• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::crypto::CipherSuiteProvider;
7 use crate::group::message_signature::{AuthenticatedContentTBS, FramedContentAuthData};
8 use crate::group::GroupContext;
9 use alloc::vec::Vec;
10 use core::{
11     fmt::{self, Debug},
12     ops::Deref,
13 };
14 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
15 use mls_rs_core::error::IntoAnyError;
16 
17 use super::message_signature::AuthenticatedContent;
18 
19 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
20 struct AuthenticatedContentTBM<'a> {
21     content_tbs: AuthenticatedContentTBS<'a>,
22     auth: &'a FramedContentAuthData,
23 }
24 
25 impl<'a> AuthenticatedContentTBM<'a> {
from_authenticated_content( auth_content: &'a AuthenticatedContent, group_context: &'a GroupContext, ) -> AuthenticatedContentTBM<'a>26     pub fn from_authenticated_content(
27         auth_content: &'a AuthenticatedContent,
28         group_context: &'a GroupContext,
29     ) -> AuthenticatedContentTBM<'a> {
30         AuthenticatedContentTBM {
31             content_tbs: AuthenticatedContentTBS::from_authenticated_content(
32                 auth_content,
33                 Some(group_context),
34                 group_context.protocol_version,
35             ),
36             auth: &auth_content.auth,
37         }
38     }
39 }
40 
41 #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
42 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
43 pub struct MembershipTag(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
44 
45 impl Debug for MembershipTag {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result46     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47         mls_rs_core::debug::pretty_bytes(&self.0)
48             .named("MembershipTag")
49             .fmt(f)
50     }
51 }
52 
53 impl Deref for MembershipTag {
54     type Target = Vec<u8>;
55 
deref(&self) -> &Self::Target56     fn deref(&self) -> &Self::Target {
57         &self.0
58     }
59 }
60 
61 impl From<Vec<u8>> for MembershipTag {
from(m: Vec<u8>) -> Self62     fn from(m: Vec<u8>) -> Self {
63         Self(m)
64     }
65 }
66 
67 impl MembershipTag {
68     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
create<P: CipherSuiteProvider>( authenticated_content: &AuthenticatedContent, group_context: &GroupContext, membership_key: &[u8], cipher_suite_provider: &P, ) -> Result<Self, MlsError>69     pub(crate) async fn create<P: CipherSuiteProvider>(
70         authenticated_content: &AuthenticatedContent,
71         group_context: &GroupContext,
72         membership_key: &[u8],
73         cipher_suite_provider: &P,
74     ) -> Result<Self, MlsError> {
75         let plaintext_tbm = AuthenticatedContentTBM::from_authenticated_content(
76             authenticated_content,
77             group_context,
78         );
79 
80         let serialized_tbm = plaintext_tbm.mls_encode_to_vec()?;
81 
82         let tag = cipher_suite_provider
83             .mac(membership_key, &serialized_tbm)
84             .await
85             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
86 
87         Ok(MembershipTag(tag))
88     }
89 }
90 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::test_utils::try_test_cipher_suite_provider;
    use crate::group::{
        framing::test_utils::get_test_auth_content, test_utils::get_test_group_context,
    };

    // Only the synchronous test-vector generator uses these helpers. Gating them keeps
    // async builds free of unused-import warnings.
    #[cfg(not(mls_build_async))]
    use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One stored test vector: a cipher suite id and the expected tag bytes (hex in JSON).
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        tag: Vec<u8>,
    }

    /// Produces one test vector per cipher suite supported by `TestCryptoProvider`.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_cases() -> Vec<TestCase> {
        let mut test_cases = Vec::new();

        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let tag = MembershipTag::create(
                &get_test_auth_content(),
                &get_test_group_context(1, cipher_suite),
                b"membership_key".as_ref(),
                &test_cipher_suite_provider(cipher_suite),
            )
            .unwrap();

            test_cases.push(TestCase {
                cipher_suite: cipher_suite.into(),
                tag: tag.to_vec(),
            });
        }

        test_cases
    }

    /// Test vectors can only be generated by the synchronous build.
    #[cfg(mls_build_async)]
    fn generate_test_cases() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    /// Loads the stored `membership_tag` vectors, generating them if absent.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(membership_tag, generate_test_cases())
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_membership_tag() {
        for case in load_test_cases() {
            // Skip vectors for cipher suites the current crypto provider does not support.
            let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            // Reuse the provider obtained above rather than constructing a second one.
            let tag = MembershipTag::create(
                &get_test_auth_content(),
                &get_test_group_context(1, cs_provider.cipher_suite()).await,
                b"membership_key".as_ref(),
                &cs_provider,
            )
            .await
            .unwrap();

            assert_eq!(**tag, case.tag);
        }
    }
}
164