• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

5 use crate::cipher_suite::CipherSuite;
6 use crate::client::MlsError;
7 use crate::crypto::HpkePublicKey;
8 use crate::hash_reference::HashReference;
9 use crate::identity::SigningIdentity;
10 use crate::protocol_version::ProtocolVersion;
11 use crate::signer::Signable;
12 use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
13 use crate::CipherSuiteProvider;
14 use alloc::vec::Vec;
15 use core::{
16     fmt::{self, Debug},
17     ops::Deref,
18 };
19 use mls_rs_codec::MlsDecode;
20 use mls_rs_codec::MlsEncode;
21 use mls_rs_codec::MlsSize;
22 use mls_rs_core::extension::ExtensionList;
23 
24 mod validator;
25 pub(crate) use validator::*;
26 
27 pub(crate) mod generator;
28 pub(crate) use generator::*;
29 
30 #[non_exhaustive]
31 #[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
32 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
33 #[cfg_attr(
34     all(feature = "ffi", not(test)),
35     safer_ffi_gen::ffi_type(clone, opaque)
36 )]
37 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
38 pub struct KeyPackage {
39     pub version: ProtocolVersion,
40     pub cipher_suite: CipherSuite,
41     pub hpke_init_key: HpkePublicKey,
42     pub(crate) leaf_node: LeafNode,
43     pub extensions: ExtensionList,
44     #[mls_codec(with = "mls_rs_codec::byte_vec")]
45     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
46     pub signature: Vec<u8>,
47 }
48 
49 impl Debug for KeyPackage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result50     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51         f.debug_struct("KeyPackage")
52             .field("version", &self.version)
53             .field("cipher_suite", &self.cipher_suite)
54             .field("hpke_init_key", &self.hpke_init_key)
55             .field("leaf_node", &self.leaf_node)
56             .field("extensions", &self.extensions)
57             .field(
58                 "signature",
59                 &mls_rs_core::debug::pretty_bytes(&self.signature),
60             )
61             .finish()
62     }
63 }
64 
65 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
66 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
67 #[cfg_attr(
68     all(feature = "ffi", not(test)),
69     safer_ffi_gen::ffi_type(clone, opaque)
70 )]
71 pub struct KeyPackageRef(HashReference);
72 
73 impl Deref for KeyPackageRef {
74     type Target = [u8];
75 
deref(&self) -> &Self::Target76     fn deref(&self) -> &Self::Target {
77         &self.0
78     }
79 }
80 
81 impl From<Vec<u8>> for KeyPackageRef {
from(v: Vec<u8>) -> Self82     fn from(v: Vec<u8>) -> Self {
83         Self(HashReference::from(v))
84     }
85 }
86 
87 #[derive(MlsSize, MlsEncode)]
88 struct KeyPackageData<'a> {
89     pub version: ProtocolVersion,
90     pub cipher_suite: CipherSuite,
91     #[mls_codec(with = "mls_rs_codec::byte_vec")]
92     pub hpke_init_key: &'a HpkePublicKey,
93     pub leaf_node: &'a LeafNode,
94     pub extensions: &'a ExtensionList,
95 }
96 
97 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
98 impl KeyPackage {
99     #[cfg(feature = "ffi")]
version(&self) -> ProtocolVersion100     pub fn version(&self) -> ProtocolVersion {
101         self.version
102     }
103 
104     #[cfg(feature = "ffi")]
cipher_suite(&self) -> CipherSuite105     pub fn cipher_suite(&self) -> CipherSuite {
106         self.cipher_suite
107     }
108 
signing_identity(&self) -> &SigningIdentity109     pub fn signing_identity(&self) -> &SigningIdentity {
110         &self.leaf_node.signing_identity
111     }
112 
113     #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
114     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
to_reference<CP: CipherSuiteProvider>( &self, cipher_suite_provider: &CP, ) -> Result<KeyPackageRef, MlsError>115     pub async fn to_reference<CP: CipherSuiteProvider>(
116         &self,
117         cipher_suite_provider: &CP,
118     ) -> Result<KeyPackageRef, MlsError> {
119         if cipher_suite_provider.cipher_suite() != self.cipher_suite {
120             return Err(MlsError::CipherSuiteMismatch);
121         }
122 
123         Ok(KeyPackageRef(
124             HashReference::compute(
125                 &self.mls_encode_to_vec()?,
126                 b"MLS 1.0 KeyPackage Reference",
127                 cipher_suite_provider,
128             )
129             .await?,
130         ))
131     }
132 
expiration(&self) -> Result<u64, MlsError>133     pub fn expiration(&self) -> Result<u64, MlsError> {
134         if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
135             Ok(lifetime.not_after)
136         } else {
137             Err(MlsError::InvalidLeafNodeSource)
138         }
139     }
140 }
141 
142 impl<'a> Signable<'a> for KeyPackage {
143     const SIGN_LABEL: &'static str = "KeyPackageTBS";
144 
145     type SigningContext = ();
146 
signature(&self) -> &[u8]147     fn signature(&self) -> &[u8] {
148         &self.signature
149     }
150 
signable_content( &self, _context: &Self::SigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>151     fn signable_content(
152         &self,
153         _context: &Self::SigningContext,
154     ) -> Result<Vec<u8>, mls_rs_codec::Error> {
155         KeyPackageData {
156             version: self.version,
157             cipher_suite: self.cipher_suite,
158             hpke_init_key: &self.hpke_init_key,
159             leaf_node: &self.leaf_node,
160             extensions: &self.extensions,
161         }
162         .mls_encode_to_vec()
163     }
164 
write_signature(&mut self, signature: Vec<u8>)165     fn write_signature(&mut self, signature: Vec<u8>) {
166         self.signature = signature
167     }
168 }
169 
#[cfg(test)]
pub(crate) mod test_utils {
    use super::*;
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::framing::MlsMessagePayload,
        identity::test_utils::get_test_signing_identity,
        tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
        MlsMessage,
    };

    use mls_rs_core::crypto::SignatureSecretKey;

    /// Generates a test key package for `id`, discarding the signing key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> KeyPackage {
        test_key_package_with_signer(protocol_version, cipher_suite, id)
            .await
            .0
    }

    /// Generates a test key package for `id` and returns it together with the
    /// secret key that signed it.
    ///
    /// The key package uses a one-year lifetime, the test capabilities and
    /// default extension lists. Panics if generation fails (test-only helper).
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_with_signer(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> (KeyPackage, SignatureSecretKey) {
        let (signing_identity, secret_key) =
            get_test_signing_identity(cipher_suite, id.as_bytes()).await;

        let generator = KeyPackageGenerator {
            protocol_version,
            cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
            signing_identity: &signing_identity,
            signing_key: &secret_key,
        };

        let key_package = generator
            .generate(
                Lifetime::years(1).unwrap(),
                get_test_capabilities(),
                ExtensionList::default(),
                ExtensionList::default(),
            )
            .await
            .unwrap()
            .key_package;

        (key_package, secret_key)
    }

    /// Wraps a freshly generated test key package for `id` in an `MlsMessage`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_key_package_message(
        protocol_version: ProtocolVersion,
        cipher_suite: CipherSuite,
        id: &str,
    ) -> MlsMessage {
        MlsMessage::new(
            protocol_version,
            MlsMessagePayload::KeyPackage(
                test_key_package(protocol_version, cipher_suite, id).await,
            ),
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
    };

    use super::{test_utils::test_key_package, *};
    use alloc::format;
    use assert_matches::assert_matches;

    /// One known-answer vector: a serialized key package (`input`) and its
    /// expected reference bytes (`output`) under `cipher_suite`.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }

    impl TestCase {
        /// Builds one vector per (protocol version, cipher suite) pair.
        /// Presumably called by `load_test_case_json!` to (re)generate the
        /// stored vectors.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let mut test_cases = Vec::new();

            for (i, (protocol_version, cipher_suite)) in ProtocolVersion::all()
                .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
                .enumerate()
            {
                // A distinct identity per vector keeps every key package unique.
                let pkg =
                    test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await;

                let pkg_ref = pkg
                    .to_reference(&test_cipher_suite_provider(cipher_suite))
                    .await
                    .unwrap();

                let case = TestCase {
                    cipher_suite: cipher_suite.into(),
                    input: pkg.mls_encode_to_vec().unwrap(),
                    output: pkg_ref.to_vec(),
                };

                test_cases.push(case);
            }

            test_cases
        }
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(key_package_ref, TestCase::generate())
    }

    /// Decoding each stored key package and recomputing its reference must
    /// reproduce the stored output. Suites without an available provider are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_ref() {
        let cases = load_test_cases().await;

        for one_case in cases {
            let Some(provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
                continue;
            };

            let key_package = KeyPackage::mls_decode(&mut one_case.input.as_slice()).unwrap();

            let key_package_ref = key_package.to_reference(&provider).await.unwrap();

            let expected_out = KeyPackageRef::from(one_case.output);
            assert_eq!(expected_out, key_package_ref);
        }
    }

    /// Computing a reference with a provider for any other cipher suite must
    /// fail with `CipherSuiteMismatch`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn key_package_ref_fails_invalid_cipher_suite() {
        let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;

        for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) {
            if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) {
                let res = key_package.to_reference(&cs).await;

                assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
            }
        }
    }
}