1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::cipher_suite::CipherSuite; 6 use crate::client::MlsError; 7 use crate::crypto::HpkePublicKey; 8 use crate::hash_reference::HashReference; 9 use crate::identity::SigningIdentity; 10 use crate::protocol_version::ProtocolVersion; 11 use crate::signer::Signable; 12 use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource}; 13 use crate::CipherSuiteProvider; 14 use alloc::vec::Vec; 15 use core::{ 16 fmt::{self, Debug}, 17 ops::Deref, 18 }; 19 use mls_rs_codec::MlsDecode; 20 use mls_rs_codec::MlsEncode; 21 use mls_rs_codec::MlsSize; 22 use mls_rs_core::extension::ExtensionList; 23 24 mod validator; 25 pub(crate) use validator::*; 26 27 pub(crate) mod generator; 28 pub(crate) use generator::*; 29 30 #[non_exhaustive] 31 #[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)] 32 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 33 #[cfg_attr( 34 all(feature = "ffi", not(test)), 35 safer_ffi_gen::ffi_type(clone, opaque) 36 )] 37 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 38 pub struct KeyPackage { 39 pub version: ProtocolVersion, 40 pub cipher_suite: CipherSuite, 41 pub hpke_init_key: HpkePublicKey, 42 pub(crate) leaf_node: LeafNode, 43 pub extensions: ExtensionList, 44 #[mls_codec(with = "mls_rs_codec::byte_vec")] 45 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 46 pub signature: Vec<u8>, 47 } 48 49 impl Debug for KeyPackage { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 51 f.debug_struct("KeyPackage") 52 .field("version", &self.version) 53 .field("cipher_suite", &self.cipher_suite) 54 .field("hpke_init_key", &self.hpke_init_key) 55 .field("leaf_node", &self.leaf_node) 56 .field("extensions", &self.extensions) 57 .field( 58 "signature", 59 &mls_rs_core::debug::pretty_bytes(&self.signature), 60 ) 61 .finish() 62 } 63 } 64 65 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)] 66 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 67 #[cfg_attr( 68 all(feature = "ffi", not(test)), 69 safer_ffi_gen::ffi_type(clone, opaque) 70 )] 71 pub struct KeyPackageRef(HashReference); 72 73 impl Deref for KeyPackageRef { 74 type Target = [u8]; 75 deref(&self) -> &Self::Target76 fn deref(&self) -> &Self::Target { 77 &self.0 78 } 79 } 80 81 impl From<Vec<u8>> for KeyPackageRef { from(v: Vec<u8>) -> Self82 fn from(v: Vec<u8>) -> Self { 83 Self(HashReference::from(v)) 84 } 85 } 86 87 #[derive(MlsSize, MlsEncode)] 88 struct KeyPackageData<'a> { 89 pub version: ProtocolVersion, 90 pub cipher_suite: CipherSuite, 91 #[mls_codec(with = "mls_rs_codec::byte_vec")] 92 pub hpke_init_key: &'a HpkePublicKey, 93 pub leaf_node: &'a LeafNode, 94 pub extensions: &'a ExtensionList, 95 } 96 97 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 98 impl KeyPackage { 99 #[cfg(feature = "ffi")] version(&self) -> ProtocolVersion100 pub fn version(&self) -> ProtocolVersion { 101 self.version 102 } 103 104 #[cfg(feature = "ffi")] cipher_suite(&self) -> CipherSuite105 pub fn cipher_suite(&self) -> CipherSuite { 106 self.cipher_suite 107 } 108 signing_identity(&self) -> &SigningIdentity109 pub fn signing_identity(&self) -> &SigningIdentity { 110 &self.leaf_node.signing_identity 111 } 112 113 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)] 114 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] to_reference<CP: CipherSuiteProvider>( &self, cipher_suite_provider: &CP, ) -> Result<KeyPackageRef, MlsError>115 pub async fn to_reference<CP: CipherSuiteProvider>( 116 &self, 117 cipher_suite_provider: &CP, 118 ) -> Result<KeyPackageRef, MlsError> { 119 if cipher_suite_provider.cipher_suite() != self.cipher_suite { 120 return Err(MlsError::CipherSuiteMismatch); 121 } 122 123 Ok(KeyPackageRef( 124 HashReference::compute( 125 &self.mls_encode_to_vec()?, 126 b"MLS 1.0 KeyPackage Reference", 127 cipher_suite_provider, 128 ) 129 .await?, 130 )) 131 } 132 expiration(&self) -> Result<u64, MlsError>133 pub fn expiration(&self) -> Result<u64, MlsError> { 134 if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source { 135 Ok(lifetime.not_after) 136 } else { 137 Err(MlsError::InvalidLeafNodeSource) 138 } 139 } 140 } 141 142 impl<'a> Signable<'a> for KeyPackage { 143 const SIGN_LABEL: &'static str = "KeyPackageTBS"; 144 145 type SigningContext = (); 146 signature(&self) -> &[u8]147 fn signature(&self) -> &[u8] { 148 &self.signature 149 } 150 signable_content( &self, _context: &Self::SigningContext, ) -> Result<Vec<u8>, mls_rs_codec::Error>151 fn signable_content( 152 &self, 153 _context: &Self::SigningContext, 154 ) -> Result<Vec<u8>, mls_rs_codec::Error> { 155 KeyPackageData { 156 version: self.version, 157 cipher_suite: self.cipher_suite, 158 hpke_init_key: &self.hpke_init_key, 159 leaf_node: &self.leaf_node, 160 extensions: &self.extensions, 161 } 162 .mls_encode_to_vec() 163 } 164 write_signature(&mut self, signature: Vec<u8>)165 fn write_signature(&mut self, signature: Vec<u8>) { 166 self.signature = signature 167 } 168 } 169 170 #[cfg(test)] 171 pub(crate) mod test_utils { 172 use super::*; 173 use crate::{ 174 crypto::test_utils::test_cipher_suite_provider, 175 group::framing::MlsMessagePayload, 176 identity::test_utils::get_test_signing_identity, 177 tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime}, 178 MlsMessage, 179 }; 180 181 use mls_rs_core::crypto::SignatureSecretKey; 182 183 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_key_package( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, id: &str, ) -> KeyPackage184 pub(crate) async fn test_key_package( 185 protocol_version: ProtocolVersion, 186 cipher_suite: CipherSuite, 187 id: &str, 188 ) -> KeyPackage { 189 test_key_package_with_signer(protocol_version, cipher_suite, id) 190 .await 191 .0 192 } 193 194 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_key_package_with_signer( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, id: &str, ) -> (KeyPackage, SignatureSecretKey)195 pub(crate) async fn test_key_package_with_signer( 196 protocol_version: ProtocolVersion, 197 cipher_suite: CipherSuite, 198 id: &str, 199 ) -> (KeyPackage, SignatureSecretKey) { 200 let (signing_identity, secret_key) = 201 get_test_signing_identity(cipher_suite, id.as_bytes()).await; 202 203 let generator = KeyPackageGenerator { 204 protocol_version, 205 cipher_suite_provider: &test_cipher_suite_provider(cipher_suite), 206 signing_identity: &signing_identity, 207 signing_key: &secret_key, 208 }; 209 210 let key_package = generator 211 .generate( 212 Lifetime::years(1).unwrap(), 213 get_test_capabilities(), 214 ExtensionList::default(), 215 ExtensionList::default(), 216 ) 217 .await 218 .unwrap() 219 .key_package; 220 221 (key_package, secret_key) 222 } 223 224 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_key_package_message( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, id: &str, ) -> MlsMessage225 pub(crate) async fn test_key_package_message( 226 protocol_version: ProtocolVersion, 227 cipher_suite: CipherSuite, 228 id: &str, 229 ) -> MlsMessage { 230 MlsMessage::new( 231 protocol_version, 232 MlsMessagePayload::KeyPackage( 233 test_key_package(protocol_version, cipher_suite, id).await, 234 ), 235 ) 236 } 237 } 238 239 #[cfg(test)] 240 mod tests { 241 use crate::{ 242 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, 243 crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider}, 244 }; 245 246 use super::{test_utils::test_key_package, *}; 247 use alloc::format; 248 use assert_matches::assert_matches; 249 250 #[derive(serde::Deserialize, serde::Serialize)] 251 struct TestCase { 252 cipher_suite: u16, 253 #[serde(with = "hex::serde")] 254 input: Vec<u8>, 255 #[serde(with = "hex::serde")] 256 output: Vec<u8>, 257 } 258 259 impl TestCase { 260 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 261 #[cfg_attr(coverage_nightly, coverage(off))] generate() -> Vec<TestCase>262 async fn generate() -> Vec<TestCase> { 263 let mut test_cases = Vec::new(); 264 265 for (i, (protocol_version, cipher_suite)) in ProtocolVersion::all() 266 .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs))) 267 .enumerate() 268 { 269 let pkg = 270 test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await; 271 272 let pkg_ref = pkg 273 .to_reference(&test_cipher_suite_provider(cipher_suite)) 274 .await 275 .unwrap(); 276 277 let case = TestCase { 278 cipher_suite: cipher_suite.into(), 279 input: pkg.mls_encode_to_vec().unwrap(), 280 output: pkg_ref.to_vec(), 281 }; 282 283 test_cases.push(case); 284 } 285 286 test_cases 287 } 288 } 289 290 #[cfg(mls_build_async)] load_test_cases() -> Vec<TestCase>291 async fn load_test_cases() -> Vec<TestCase> { 292 load_test_case_json!(key_package_ref, TestCase::generate().await) 293 } 294 295 #[cfg(not(mls_build_async))] load_test_cases() -> Vec<TestCase>296 fn load_test_cases() -> Vec<TestCase> { 297 load_test_case_json!(key_package_ref, TestCase::generate()) 298 } 299 300 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_key_package_ref()301 async fn test_key_package_ref() { 302 let cases = load_test_cases().await; 303 304 for one_case in cases { 305 let Some(provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else { 306 continue; 307 }; 308 309 let key_package = KeyPackage::mls_decode(&mut one_case.input.as_slice()).unwrap(); 310 311 let key_package_ref = key_package.to_reference(&provider).await.unwrap(); 312 313 let expected_out = KeyPackageRef::from(one_case.output); 314 assert_eq!(expected_out, key_package_ref); 315 } 316 } 317 318 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] key_package_ref_fails_invalid_cipher_suite()319 async fn key_package_ref_fails_invalid_cipher_suite() { 320 let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await; 321 322 for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) { 323 if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) { 324 let res = key_package.to_reference(&cs).await; 325 326 assert_matches!(res, Err(MlsError::CipherSuiteMismatch)); 327 } 328 } 329 } 330 } 331