1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec; 6 use alloc::vec::Vec; 7 use mls_rs_codec::{MlsDecode, MlsEncode}; 8 use mls_rs_core::{error::IntoAnyError, key_package::KeyPackageData}; 9 10 use crate::client::MlsError; 11 use crate::{ 12 crypto::{HpkeSecretKey, SignatureSecretKey}, 13 group::framing::MlsMessagePayload, 14 identity::SigningIdentity, 15 protocol_version::ProtocolVersion, 16 signer::Signable, 17 tree_kem::{ 18 leaf_node::{ConfigProperties, LeafNode}, 19 Capabilities, Lifetime, 20 }, 21 CipherSuiteProvider, ExtensionList, MlsMessage, 22 }; 23 24 use super::{KeyPackage, KeyPackageRef}; 25 26 #[derive(Clone, Debug)] 27 pub struct KeyPackageGenerator<'a, CP> 28 where 29 CP: CipherSuiteProvider, 30 { 31 pub protocol_version: ProtocolVersion, 32 pub cipher_suite_provider: &'a CP, 33 pub signing_identity: &'a SigningIdentity, 34 pub signing_key: &'a SignatureSecretKey, 35 } 36 37 #[derive(Clone, Debug)] 38 pub struct KeyPackageGeneration { 39 pub(crate) reference: KeyPackageRef, 40 pub(crate) key_package: KeyPackage, 41 pub(crate) init_secret_key: HpkeSecretKey, 42 pub(crate) leaf_node_secret_key: HpkeSecretKey, 43 } 44 45 impl KeyPackageGeneration { to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError>46 pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> { 47 let id = self.reference.to_vec(); 48 49 let data = KeyPackageData::new( 50 self.key_package.mls_encode_to_vec()?, 51 self.init_secret_key.clone(), 52 self.leaf_node_secret_key.clone(), 53 self.key_package.expiration()?, 54 ); 55 56 Ok((id, data)) 57 } 58 from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError>59 pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> { 60 Ok(KeyPackageGeneration { 61 reference: KeyPackageRef::from(id), 62 key_package: KeyPackage::mls_decode(&mut &*data.key_package_bytes)?, 63 init_secret_key: data.init_key, 64 leaf_node_secret_key: data.leaf_node_key, 65 }) 66 } 67 key_package_message(&self) -> MlsMessage68 pub fn key_package_message(&self) -> MlsMessage { 69 MlsMessage::new( 70 self.key_package.version, 71 MlsMessagePayload::KeyPackage(self.key_package.clone()), 72 ) 73 } 74 } 75 76 impl<'a, CP> KeyPackageGenerator<'a, CP> 77 where 78 CP: CipherSuiteProvider, 79 { 80 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] sign(&self, package: &mut KeyPackage) -> Result<(), MlsError>81 pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> { 82 package 83 .sign(self.cipher_suite_provider, self.signing_key, &()) 84 .await 85 } 86 87 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] generate( &self, lifetime: Lifetime, capabilities: Capabilities, key_package_extensions: ExtensionList, leaf_node_extensions: ExtensionList, ) -> Result<KeyPackageGeneration, MlsError>88 pub async fn generate( 89 &self, 90 lifetime: Lifetime, 91 capabilities: Capabilities, 92 key_package_extensions: ExtensionList, 93 leaf_node_extensions: ExtensionList, 94 ) -> Result<KeyPackageGeneration, MlsError> { 95 let (init_secret_key, public_init) = self 96 .cipher_suite_provider 97 .kem_generate() 98 .await 99 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 100 101 let properties = ConfigProperties { 102 capabilities, 103 extensions: leaf_node_extensions, 104 }; 105 106 let (leaf_node, leaf_node_secret) = LeafNode::generate( 107 self.cipher_suite_provider, 108 properties, 109 self.signing_identity.clone(), 110 self.signing_key, 111 lifetime, 112 ) 113 .await?; 114 115 let mut package = KeyPackage { 116 version: self.protocol_version, 117 cipher_suite: self.cipher_suite_provider.cipher_suite(), 118 hpke_init_key: public_init, 119 leaf_node, 120 extensions: key_package_extensions, 121 signature: vec![], 122 }; 123 124 package.grease(self.cipher_suite_provider)?; 125 126 self.sign(&mut package).await?; 127 128 let reference = package.to_reference(self.cipher_suite_provider).await?; 129 130 Ok(KeyPackageGeneration { 131 key_package: package, 132 init_secret_key, 133 leaf_node_secret_key: leaf_node_secret, 134 reference, 135 }) 136 } 137 } 138 139 #[cfg(test)] 140 mod tests { 141 use assert_matches::assert_matches; 142 use mls_rs_core::crypto::CipherSuiteProvider; 143 144 use crate::{ 145 crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}, 146 extension::test_utils::TestExtension, 147 group::test_utils::random_bytes, 148 identity::basic::BasicIdentityProvider, 149 identity::test_utils::get_test_signing_identity, 150 key_package::validate_key_package_properties, 151 protocol_version::ProtocolVersion, 152 tree_kem::{ 153 leaf_node::{test_utils::get_test_capabilities, LeafNodeSource}, 154 leaf_node_validator::{LeafNodeValidator, ValidationContext}, 155 Lifetime, 156 }, 157 ExtensionList, 158 }; 159 160 use super::KeyPackageGenerator; 161 test_key_package_ext(val: u8) -> ExtensionList162 fn test_key_package_ext(val: u8) -> ExtensionList { 163 let mut ext_list = ExtensionList::new(); 164 ext_list.set_from(TestExtension::from(val)).unwrap(); 165 ext_list 166 } 167 test_leaf_node_ext(val: u8) -> ExtensionList168 fn test_leaf_node_ext(val: u8) -> ExtensionList { 169 let mut ext_list = ExtensionList::new(); 170 ext_list.set_from(TestExtension::from(val)).unwrap(); 171 ext_list 172 } 173 test_lifetime() -> Lifetime174 fn test_lifetime() -> Lifetime { 175 Lifetime::years(1).unwrap() 176 } 177 178 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_key_generation()179 async fn test_key_generation() { 180 for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| { 181 TestCryptoProvider::all_supported_cipher_suites() 182 .into_iter() 183 .map(move |cs| (p, cs)) 184 }) { 185 let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); 186 187 let (signing_identity, signing_key) = 188 get_test_signing_identity(cipher_suite, b"foo").await; 189 190 let key_package_ext = test_key_package_ext(32); 191 let leaf_node_ext = test_leaf_node_ext(42); 192 let lifetime = test_lifetime(); 193 194 let test_generator = KeyPackageGenerator { 195 protocol_version, 196 cipher_suite_provider: &cipher_suite_provider, 197 signing_identity: &signing_identity, 198 signing_key: &signing_key, 199 }; 200 201 let mut capabilities = get_test_capabilities(); 202 capabilities.extensions.push(42.into()); 203 capabilities.extensions.push(43.into()); 204 capabilities.extensions.push(32.into()); 205 206 let generated = test_generator 207 .generate( 208 lifetime.clone(), 209 capabilities.clone(), 210 key_package_ext.clone(), 211 leaf_node_ext.clone(), 212 ) 213 .await 214 .unwrap(); 215 216 assert_matches!(generated.key_package.leaf_node.leaf_node_source, 217 LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime); 218 219 assert_eq!( 220 generated.key_package.leaf_node.ungreased_capabilities(), 221 capabilities 222 ); 223 224 assert_eq!( 225 generated.key_package.leaf_node.ungreased_extensions(), 226 leaf_node_ext 227 ); 228 229 assert_eq!( 230 generated.key_package.ungreased_extensions(), 231 key_package_ext 232 ); 233 234 assert_ne!( 235 generated.key_package.hpke_init_key.as_ref(), 236 generated.key_package.leaf_node.public_key.as_ref() 237 ); 238 239 assert_eq!(generated.key_package.cipher_suite, cipher_suite); 240 assert_eq!(generated.key_package.version, protocol_version); 241 242 // Verify that the hpke key pair generated will work 243 let test_data = random_bytes(32); 244 245 let sealed = cipher_suite_provider 246 .hpke_seal(&generated.key_package.hpke_init_key, &[], None, &test_data) 247 .await 248 .unwrap(); 249 250 let opened = cipher_suite_provider 251 .hpke_open( 252 &sealed, 253 &generated.init_secret_key, 254 &generated.key_package.hpke_init_key, 255 &[], 256 None, 257 ) 258 .await 259 .unwrap(); 260 261 assert_eq!(opened, test_data); 262 263 let validator = 264 LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None); 265 266 validator 267 .check_if_valid( 268 &generated.key_package.leaf_node, 269 ValidationContext::Add(None), 270 ) 271 .await 272 .unwrap(); 273 274 validate_key_package_properties( 275 &generated.key_package, 276 protocol_version, 277 &cipher_suite_provider, 278 ) 279 .await 280 .unwrap(); 281 } 282 } 283 284 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_randomness()285 async fn test_randomness() { 286 for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| { 287 TestCryptoProvider::all_supported_cipher_suites() 288 .into_iter() 289 .map(move |cs| (p, cs)) 290 }) { 291 let (signing_identity, signing_key) = 292 get_test_signing_identity(cipher_suite, b"foo").await; 293 294 let test_generator = KeyPackageGenerator { 295 protocol_version, 296 cipher_suite_provider: &test_cipher_suite_provider(cipher_suite), 297 signing_identity: &signing_identity, 298 signing_key: &signing_key, 299 }; 300 301 let first_key_package = test_generator 302 .generate( 303 test_lifetime(), 304 get_test_capabilities(), 305 ExtensionList::default(), 306 ExtensionList::default(), 307 ) 308 .await 309 .unwrap(); 310 311 for _ in 0..100 { 312 let next_key_package = test_generator 313 .generate( 314 test_lifetime(), 315 get_test_capabilities(), 316 ExtensionList::default(), 317 ExtensionList::default(), 318 ) 319 .await 320 .unwrap(); 321 322 assert_ne!( 323 first_key_package.key_package.hpke_init_key, 324 next_key_package.key_package.hpke_init_key 325 ); 326 327 assert_ne!( 328 first_key_package.key_package.leaf_node.public_key, 329 next_key_package.key_package.leaf_node.public_key 330 ); 331 } 332 } 333 } 334 } 335