1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::error::IntoAnyError; 6 #[cfg(mls_build_async)] 7 use alloc::boxed::Box; 8 use alloc::vec::Vec; 9 use core::{ 10 fmt::{self, Debug}, 11 ops::Deref, 12 }; 13 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 14 use zeroize::Zeroizing; 15 16 #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 17 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 18 /// Wrapper type that holds a pre-shared key value and zeroizes on drop. 19 pub struct PreSharedKey( 20 #[mls_codec(with = "mls_rs_codec::byte_vec")] 21 #[cfg_attr(feature = "serde", serde(with = "crate::zeroizing_serde"))] 22 Zeroizing<Vec<u8>>, 23 ); 24 25 impl Debug for PreSharedKey { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 27 crate::debug::pretty_bytes(&self.0) 28 .named("PreSharedKey") 29 .fmt(f) 30 } 31 } 32 33 impl PreSharedKey { 34 /// Create a new PreSharedKey. new(data: Vec<u8>) -> Self35 pub fn new(data: Vec<u8>) -> Self { 36 PreSharedKey(Zeroizing::new(data)) 37 } 38 39 /// Raw byte value. raw_value(&self) -> &[u8]40 pub fn raw_value(&self) -> &[u8] { 41 &self.0 42 } 43 } 44 45 impl From<Vec<u8>> for PreSharedKey { from(bytes: Vec<u8>) -> Self46 fn from(bytes: Vec<u8>) -> Self { 47 Self::new(bytes) 48 } 49 } 50 51 impl From<Zeroizing<Vec<u8>>> for PreSharedKey { from(bytes: Zeroizing<Vec<u8>>) -> Self52 fn from(bytes: Zeroizing<Vec<u8>>) -> Self { 53 Self(bytes) 54 } 55 } 56 57 impl AsRef<[u8]> for PreSharedKey { as_ref(&self) -> &[u8]58 fn as_ref(&self) -> &[u8] { 59 self.raw_value() 60 } 61 } 62 63 impl Deref for PreSharedKey { 64 type Target = [u8]; 65 deref(&self) -> &Self::Target66 fn deref(&self) -> &Self::Target { 67 self.raw_value() 68 } 69 } 70 71 #[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)] 72 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 73 #[cfg_attr( 74 all(feature = "ffi", not(test)), 75 safer_ffi_gen::ffi_type(clone, opaque) 76 )] 77 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 78 /// An external pre-shared key identifier. 79 pub struct ExternalPskId( 80 #[mls_codec(with = "mls_rs_codec::byte_vec")] 81 #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))] 82 Vec<u8>, 83 ); 84 85 impl Debug for ExternalPskId { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 87 crate::debug::pretty_bytes(&self.0) 88 .named("ExternalPskId") 89 .fmt(f) 90 } 91 } 92 93 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 94 impl ExternalPskId { new(id_data: Vec<u8>) -> Self95 pub fn new(id_data: Vec<u8>) -> Self { 96 Self(id_data) 97 } 98 } 99 100 impl AsRef<[u8]> for ExternalPskId { as_ref(&self) -> &[u8]101 fn as_ref(&self) -> &[u8] { 102 &self.0 103 } 104 } 105 106 impl Deref for ExternalPskId { 107 type Target = [u8]; 108 deref(&self) -> &Self::Target109 fn deref(&self) -> &Self::Target { 110 &self.0 111 } 112 } 113 114 impl From<Vec<u8>> for ExternalPskId { from(value: Vec<u8>) -> Self115 fn from(value: Vec<u8>) -> Self { 116 ExternalPskId(value) 117 } 118 } 119 120 /// Storage trait to maintain a set of pre-shared key values. 121 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 122 #[cfg_attr(mls_build_async, maybe_async::must_be_async)] 123 pub trait PreSharedKeyStorage: Send + Sync { 124 /// Error type that the underlying storage mechanism returns on internal 125 /// failure. 126 type Error: IntoAnyError; 127 128 /// Get a pre-shared key by [`ExternalPskId`](ExternalPskId). 129 /// 130 /// `None` should be returned if a pre-shared key can not be found for `id`. get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error>131 async fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error>; 132 133 /// Determines if a PSK is located within the store contains(&self, id: &ExternalPskId) -> Result<bool, Self::Error>134 async fn contains(&self, id: &ExternalPskId) -> Result<bool, Self::Error> { 135 self.get(id).await.map(|key| key.is_some()) 136 } 137 } 138