• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 use mls_rs_core::{error::IntoAnyError, key_package::KeyPackageData};
9 
10 use crate::client::MlsError;
11 use crate::{
12     crypto::{HpkeSecretKey, SignatureSecretKey},
13     group::framing::MlsMessagePayload,
14     identity::SigningIdentity,
15     protocol_version::ProtocolVersion,
16     signer::Signable,
17     tree_kem::{
18         leaf_node::{ConfigProperties, LeafNode},
19         Capabilities, Lifetime,
20     },
21     CipherSuiteProvider, ExtensionList, MlsMessage,
22 };
23 
24 use super::{KeyPackage, KeyPackageRef};
25 
26 #[derive(Clone, Debug)]
27 pub struct KeyPackageGenerator<'a, CP>
28 where
29     CP: CipherSuiteProvider,
30 {
31     pub protocol_version: ProtocolVersion,
32     pub cipher_suite_provider: &'a CP,
33     pub signing_identity: &'a SigningIdentity,
34     pub signing_key: &'a SignatureSecretKey,
35 }
36 
37 #[derive(Clone, Debug)]
38 pub struct KeyPackageGeneration {
39     pub(crate) reference: KeyPackageRef,
40     pub(crate) key_package: KeyPackage,
41     pub(crate) init_secret_key: HpkeSecretKey,
42     pub(crate) leaf_node_secret_key: HpkeSecretKey,
43 }
44 
45 impl KeyPackageGeneration {
to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError>46     pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> {
47         let id = self.reference.to_vec();
48 
49         let data = KeyPackageData::new(
50             self.key_package.mls_encode_to_vec()?,
51             self.init_secret_key.clone(),
52             self.leaf_node_secret_key.clone(),
53             self.key_package.expiration()?,
54         );
55 
56         Ok((id, data))
57     }
58 
from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError>59     pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> {
60         Ok(KeyPackageGeneration {
61             reference: KeyPackageRef::from(id),
62             key_package: KeyPackage::mls_decode(&mut &*data.key_package_bytes)?,
63             init_secret_key: data.init_key,
64             leaf_node_secret_key: data.leaf_node_key,
65         })
66     }
67 
key_package_message(&self) -> MlsMessage68     pub fn key_package_message(&self) -> MlsMessage {
69         MlsMessage::new(
70             self.key_package.version,
71             MlsMessagePayload::KeyPackage(self.key_package.clone()),
72         )
73     }
74 }
75 
76 impl<'a, CP> KeyPackageGenerator<'a, CP>
77 where
78     CP: CipherSuiteProvider,
79 {
80     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
sign(&self, package: &mut KeyPackage) -> Result<(), MlsError>81     pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> {
82         package
83             .sign(self.cipher_suite_provider, self.signing_key, &())
84             .await
85     }
86 
87     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
generate( &self, lifetime: Lifetime, capabilities: Capabilities, key_package_extensions: ExtensionList, leaf_node_extensions: ExtensionList, ) -> Result<KeyPackageGeneration, MlsError>88     pub async fn generate(
89         &self,
90         lifetime: Lifetime,
91         capabilities: Capabilities,
92         key_package_extensions: ExtensionList,
93         leaf_node_extensions: ExtensionList,
94     ) -> Result<KeyPackageGeneration, MlsError> {
95         let (init_secret_key, public_init) = self
96             .cipher_suite_provider
97             .kem_generate()
98             .await
99             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
100 
101         let properties = ConfigProperties {
102             capabilities,
103             extensions: leaf_node_extensions,
104         };
105 
106         let (leaf_node, leaf_node_secret) = LeafNode::generate(
107             self.cipher_suite_provider,
108             properties,
109             self.signing_identity.clone(),
110             self.signing_key,
111             lifetime,
112         )
113         .await?;
114 
115         let mut package = KeyPackage {
116             version: self.protocol_version,
117             cipher_suite: self.cipher_suite_provider.cipher_suite(),
118             hpke_init_key: public_init,
119             leaf_node,
120             extensions: key_package_extensions,
121             signature: vec![],
122         };
123 
124         package.grease(self.cipher_suite_provider)?;
125 
126         self.sign(&mut package).await?;
127 
128         let reference = package.to_reference(self.cipher_suite_provider).await?;
129 
130         Ok(KeyPackageGeneration {
131             key_package: package,
132             init_secret_key,
133             leaf_node_secret_key: leaf_node_secret,
134             reference,
135         })
136     }
137 }
138 
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use mls_rs_core::crypto::CipherSuiteProvider;

    use crate::{
        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
        extension::test_utils::TestExtension,
        group::test_utils::random_bytes,
        identity::basic::BasicIdentityProvider,
        identity::test_utils::get_test_signing_identity,
        key_package::validate_key_package_properties,
        protocol_version::ProtocolVersion,
        tree_kem::{
            leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
            leaf_node_validator::{LeafNodeValidator, ValidationContext},
            Lifetime,
        },
        ExtensionList,
    };

    use super::KeyPackageGenerator;

    /// Builds an extension list holding a single `TestExtension` made from
    /// `val`.
    fn single_test_extension(val: u8) -> ExtensionList {
        let mut list = ExtensionList::new();
        list.set_from(TestExtension::from(val)).unwrap();
        list
    }

    /// Extension list used as the key package extensions in these tests.
    fn test_key_package_ext(val: u8) -> ExtensionList {
        single_test_extension(val)
    }

    /// Extension list used as the leaf node extensions in these tests.
    fn test_leaf_node_ext(val: u8) -> ExtensionList {
        single_test_extension(val)
    }

    /// One-year lifetime used by every generated key package in these tests.
    fn test_lifetime() -> Lifetime {
        Lifetime::years(1).unwrap()
    }

    /// For every protocol version and supported cipher suite, generates a key
    /// package and checks its contents, its HPKE init key pair, and that it
    /// passes leaf node and key package validation.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_generation() {
        for protocol_version in ProtocolVersion::all() {
            for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
                let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

                let (signing_identity, signing_key) =
                    get_test_signing_identity(cipher_suite, b"foo").await;

                let key_package_ext = test_key_package_ext(32);
                let leaf_node_ext = test_leaf_node_ext(42);
                let lifetime = test_lifetime();

                let test_generator = KeyPackageGenerator {
                    protocol_version,
                    cipher_suite_provider: &cipher_suite_provider,
                    signing_identity: &signing_identity,
                    signing_key: &signing_key,
                };

                // Add the extension type values used by the extension lists
                // above (42 and 32), plus 43, to the test capabilities.
                let mut capabilities = get_test_capabilities();
                capabilities.extensions.push(42.into());
                capabilities.extensions.push(43.into());
                capabilities.extensions.push(32.into());

                let generated = test_generator
                    .generate(
                        lifetime.clone(),
                        capabilities.clone(),
                        key_package_ext.clone(),
                        leaf_node_ext.clone(),
                    )
                    .await
                    .unwrap();

                let key_package = &generated.key_package;
                let leaf_node = &key_package.leaf_node;

                assert_matches!(leaf_node.leaf_node_source,
                                LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime);

                assert_eq!(leaf_node.ungreased_capabilities(), capabilities);
                assert_eq!(leaf_node.ungreased_extensions(), leaf_node_ext);
                assert_eq!(key_package.ungreased_extensions(), key_package_ext);

                // The init key and the leaf node key must be distinct keys.
                assert_ne!(
                    key_package.hpke_init_key.as_ref(),
                    leaf_node.public_key.as_ref()
                );

                assert_eq!(key_package.cipher_suite, cipher_suite);
                assert_eq!(key_package.version, protocol_version);

                // Round-trip random data through HPKE to verify that the
                // generated init key pair works.
                let test_data = random_bytes(32);

                let sealed = cipher_suite_provider
                    .hpke_seal(&key_package.hpke_init_key, &[], None, &test_data)
                    .await
                    .unwrap();

                let opened = cipher_suite_provider
                    .hpke_open(
                        &sealed,
                        &generated.init_secret_key,
                        &key_package.hpke_init_key,
                        &[],
                        None,
                    )
                    .await
                    .unwrap();

                assert_eq!(opened, test_data);

                let validator =
                    LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);

                validator
                    .check_if_valid(leaf_node, ValidationContext::Add(None))
                    .await
                    .unwrap();

                validate_key_package_properties(
                    key_package,
                    protocol_version,
                    &cipher_suite_provider,
                )
                .await
                .unwrap();
            }
        }
    }

    /// For every protocol version and supported cipher suite, checks that 100
    /// further key packages each have an init key and a leaf node public key
    /// different from those of the first generated key package.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_randomness() {
        for protocol_version in ProtocolVersion::all() {
            for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
                let (signing_identity, signing_key) =
                    get_test_signing_identity(cipher_suite, b"foo").await;

                let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

                let test_generator = KeyPackageGenerator {
                    protocol_version,
                    cipher_suite_provider: &cipher_suite_provider,
                    signing_identity: &signing_identity,
                    signing_key: &signing_key,
                };

                let first_key_package = test_generator
                    .generate(
                        test_lifetime(),
                        get_test_capabilities(),
                        ExtensionList::default(),
                        ExtensionList::default(),
                    )
                    .await
                    .unwrap();

                let first_public = &first_key_package.key_package;

                for _ in 0..100 {
                    let next_key_package = test_generator
                        .generate(
                            test_lifetime(),
                            get_test_capabilities(),
                            ExtensionList::default(),
                            ExtensionList::default(),
                        )
                        .await
                        .unwrap();

                    let next_public = &next_key_package.key_package;

                    assert_ne!(first_public.hpke_init_key, next_public.hpke_init_key);

                    assert_ne!(
                        first_public.leaf_node.public_key,
                        next_public.leaf_node.public_key
                    );
                }
            }
        }
    }
}
335