1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 7 #[cfg(any(test, feature = "external_client"))] 8 use alloc::vec; 9 10 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 11 12 #[cfg(any(test, feature = "external_client"))] 13 use mls_rs_core::psk::PreSharedKeyStorage; 14 15 #[cfg(any(test, feature = "external_client"))] 16 use core::convert::Infallible; 17 use core::fmt::{self, Debug}; 18 19 #[cfg(feature = "psk")] 20 use crate::{client::MlsError, CipherSuiteProvider}; 21 22 #[cfg(feature = "psk")] 23 use mls_rs_core::error::IntoAnyError; 24 25 #[cfg(feature = "psk")] 26 pub(crate) mod resolver; 27 pub(crate) mod secret; 28 29 pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey}; 30 31 #[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)] 32 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 33 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 34 pub(crate) struct PreSharedKeyID { 35 pub key_id: JustPreSharedKeyID, 36 pub psk_nonce: PskNonce, 37 } 38 39 impl PreSharedKeyID { 40 #[cfg(feature = "psk")] new<P: CipherSuiteProvider>( key_id: JustPreSharedKeyID, cs: &P, ) -> Result<Self, MlsError>41 pub(crate) fn new<P: CipherSuiteProvider>( 42 key_id: JustPreSharedKeyID, 43 cs: &P, 44 ) -> Result<Self, MlsError> { 45 Ok(Self { 46 key_id, 47 psk_nonce: PskNonce::random(cs) 48 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?, 49 }) 50 } 51 } 52 53 #[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 54 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 55 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 56 #[repr(u8)] 57 pub(crate) enum JustPreSharedKeyID { 58 External(ExternalPskId) = 1u8, 59 Resumption(ResumptionPsk) = 2u8, 60 } 61 62 #[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 63 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 64 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 65 pub(crate) struct PskGroupId( 66 #[mls_codec(with = "mls_rs_codec::byte_vec")] 67 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 68 pub Vec<u8>, 69 ); 70 71 impl Debug for PskGroupId { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 73 mls_rs_core::debug::pretty_bytes(&self.0) 74 .named("PskGroupId") 75 .fmt(f) 76 } 77 } 78 79 #[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)] 80 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 81 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 82 pub(crate) struct PskNonce( 83 #[mls_codec(with = "mls_rs_codec::byte_vec")] 84 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 85 pub Vec<u8>, 86 ); 87 88 impl Debug for PskNonce { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 90 mls_rs_core::debug::pretty_bytes(&self.0) 91 .named("PskNonce") 92 .fmt(f) 93 } 94 } 95 96 #[cfg(feature = "psk")] 97 impl PskNonce { random<P: CipherSuiteProvider>( cipher_suite_provider: &P, ) -> Result<Self, <P as CipherSuiteProvider>::Error>98 pub fn random<P: CipherSuiteProvider>( 99 cipher_suite_provider: &P, 100 ) -> Result<Self, <P as CipherSuiteProvider>::Error> { 101 Ok(Self(cipher_suite_provider.random_bytes_vec( 102 cipher_suite_provider.kdf_extract_size(), 103 )?)) 104 } 105 } 106 107 #[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 108 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 109 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 110 pub(crate) struct ResumptionPsk { 111 pub usage: ResumptionPSKUsage, 112 pub psk_group_id: PskGroupId, 113 pub psk_epoch: u64, 114 } 115 116 #[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd, MlsSize, MlsEncode, MlsDecode)] 117 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 118 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 119 #[repr(u8)] 120 pub(crate) enum ResumptionPSKUsage { 121 Application = 1u8, 122 Reinit = 2u8, 123 Branch = 3u8, 124 } 125 126 #[cfg(feature = "psk")] 127 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)] 128 struct PSKLabel<'a> { 129 id: &'a PreSharedKeyID, 130 index: u16, 131 count: u16, 132 } 133 134 #[cfg(any(test, feature = "external_client"))] 135 #[derive(Clone, Copy, Debug)] 136 pub(crate) struct AlwaysFoundPskStorage; 137 138 #[cfg(any(test, feature = "external_client"))] 139 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 140 #[cfg_attr(mls_build_async, maybe_async::must_be_async)] 141 impl PreSharedKeyStorage for AlwaysFoundPskStorage { 142 type Error = Infallible; 143 get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error>144 async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> { 145 Ok(Some(vec![].into())) 146 } 147 } 148 149 #[cfg(feature = "psk")] 150 #[cfg(test)] 151 pub(crate) mod test_utils { 152 use crate::crypto::test_utils::test_cipher_suite_provider; 153 154 use super::PskNonce; 155 use mls_rs_core::crypto::CipherSuite; 156 157 #[cfg(not(mls_build_async))] 158 use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId}; 159 160 #[cfg_attr(coverage_nightly, coverage(off))] 161 #[cfg(not(mls_build_async))] make_external_psk_id<P: CipherSuiteProvider>( cipher_suite_provider: &P, ) -> ExternalPskId162 pub(crate) fn make_external_psk_id<P: CipherSuiteProvider>( 163 cipher_suite_provider: &P, 164 ) -> ExternalPskId { 165 ExternalPskId::new( 166 cipher_suite_provider 167 .random_bytes_vec(cipher_suite_provider.kdf_extract_size()) 168 .unwrap(), 169 ) 170 } 171 make_nonce(cipher_suite: CipherSuite) -> PskNonce172 pub(crate) fn make_nonce(cipher_suite: CipherSuite) -> PskNonce { 173 PskNonce::random(&test_cipher_suite_provider(cipher_suite)).unwrap() 174 } 175 } 176 177 #[cfg(feature = "psk")] 178 #[cfg(test)] 179 mod tests { 180 use crate::crypto::test_utils::TestCryptoProvider; 181 use core::iter; 182 183 #[cfg(target_arch = "wasm32")] 184 use wasm_bindgen_test::wasm_bindgen_test as test; 185 186 use super::test_utils::make_nonce; 187 188 #[test] random_generation_of_nonces_is_random()189 fn random_generation_of_nonces_is_random() { 190 let good = TestCryptoProvider::all_supported_cipher_suites() 191 .into_iter() 192 .all(|cipher_suite| { 193 let nonce = make_nonce(cipher_suite); 194 iter::repeat_with(|| make_nonce(cipher_suite)) 195 .take(1000) 196 .all(|other| other != nonce) 197 }); 198 199 assert!(good); 200 } 201 } 202