1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::collections::VecDeque; 6 7 #[cfg(target_has_atomic = "ptr")] 8 use alloc::sync::Arc; 9 10 #[cfg(mls_build_async)] 11 use alloc::boxed::Box; 12 use alloc::vec::Vec; 13 use core::{ 14 convert::Infallible, 15 fmt::{self, Debug}, 16 }; 17 use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage}; 18 #[cfg(not(target_has_atomic = "ptr"))] 19 use portable_atomic_util::Arc; 20 21 use crate::{ 22 client::MlsError, 23 map::{LargeMap, LargeMapEntry}, 24 }; 25 26 #[cfg(feature = "std")] 27 use std::sync::{Mutex, MutexGuard}; 28 29 #[cfg(not(feature = "std"))] 30 use spin::{Mutex, MutexGuard}; 31 32 pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3; 33 34 #[derive(Clone)] 35 pub(crate) struct InMemoryGroupData { 36 pub(crate) state_data: Vec<u8>, 37 pub(crate) epoch_data: VecDeque<EpochRecord>, 38 } 39 40 impl Debug for InMemoryGroupData { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 42 f.debug_struct("InMemoryGroupData") 43 .field( 44 "state_data", 45 &mls_rs_core::debug::pretty_bytes(&self.state_data), 46 ) 47 .field("epoch_data", &self.epoch_data) 48 .finish() 49 } 50 } 51 52 impl InMemoryGroupData { new(state_data: Vec<u8>) -> InMemoryGroupData53 pub fn new(state_data: Vec<u8>) -> InMemoryGroupData { 54 InMemoryGroupData { 55 state_data, 56 epoch_data: Default::default(), 57 } 58 } 59 get_epoch_data_index(&self, epoch_id: u64) -> Option<u64>60 fn get_epoch_data_index(&self, epoch_id: u64) -> Option<u64> { 61 self.epoch_data 62 .front() 63 .and_then(|e| epoch_id.checked_sub(e.id)) 64 } 65 get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord>66 pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> { 67 self.get_epoch_data_index(epoch_id) 68 .and_then(|i| self.epoch_data.get(i as usize)) 69 } 70 get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord>71 pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> { 72 self.get_epoch_data_index(epoch_id) 73 .and_then(|i| self.epoch_data.get_mut(i as usize)) 74 } 75 insert_epoch(&mut self, epoch: EpochRecord)76 pub fn insert_epoch(&mut self, epoch: EpochRecord) { 77 self.epoch_data.push_back(epoch) 78 } 79 80 // This function does not fail if an update can't be made. If the epoch 81 // is not in the store, then it can no longer be accessed by future 82 // get_epoch calls and is no longer relevant. update_epoch(&mut self, epoch: EpochRecord)83 pub fn update_epoch(&mut self, epoch: EpochRecord) { 84 if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) { 85 *existing_epoch = epoch 86 } 87 } 88 trim_epochs(&mut self, max_epoch_retention: usize)89 pub fn trim_epochs(&mut self, max_epoch_retention: usize) { 90 while self.epoch_data.len() > max_epoch_retention { 91 self.epoch_data.pop_front(); 92 } 93 } 94 } 95 96 #[derive(Clone)] 97 /// In memory group state storage backed by a HashMap. 98 /// 99 /// All clones of an instance of this type share the same underlying HashMap. 100 pub struct InMemoryGroupStateStorage { 101 pub(crate) inner: Arc<Mutex<LargeMap<Vec<u8>, InMemoryGroupData>>>, 102 pub(crate) max_epoch_retention: usize, 103 } 104 105 impl Debug for InMemoryGroupStateStorage { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 107 f.debug_struct("InMemoryGroupStateStorage") 108 .field( 109 "inner", 110 &mls_rs_core::debug::pretty_with(|f| { 111 f.debug_map() 112 .entries( 113 self.lock() 114 .iter() 115 .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)), 116 ) 117 .finish() 118 }), 119 ) 120 .field("max_epoch_retention", &self.max_epoch_retention) 121 .finish() 122 } 123 } 124 125 impl InMemoryGroupStateStorage { 126 /// Create an empty group state storage. new() -> Self127 pub fn new() -> Self { 128 Self { 129 inner: Default::default(), 130 max_epoch_retention: DEFAULT_EPOCH_RETENTION_LIMIT, 131 } 132 } 133 with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError>134 pub fn with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError> { 135 (max_epoch_retention > 0) 136 .then_some(()) 137 .ok_or(MlsError::NonZeroRetentionRequired)?; 138 139 Ok(Self { 140 inner: self.inner, 141 max_epoch_retention, 142 }) 143 } 144 145 /// Get the set of unique group ids that have data stored. stored_groups(&self) -> Vec<Vec<u8>>146 pub fn stored_groups(&self) -> Vec<Vec<u8>> { 147 self.lock().keys().cloned().collect() 148 } 149 150 /// Delete all data corresponding to `group_id`. delete_group(&self, group_id: &[u8])151 pub fn delete_group(&self, group_id: &[u8]) { 152 self.lock().remove(group_id); 153 } 154 lock(&self) -> MutexGuard<'_, LargeMap<Vec<u8>, InMemoryGroupData>>155 fn lock(&self) -> MutexGuard<'_, LargeMap<Vec<u8>, InMemoryGroupData>> { 156 #[cfg(feature = "std")] 157 return self.inner.lock().unwrap(); 158 159 #[cfg(not(feature = "std"))] 160 return self.inner.lock(); 161 } 162 } 163 164 impl Default for InMemoryGroupStateStorage { default() -> Self165 fn default() -> Self { 166 Self::new() 167 } 168 } 169 170 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 171 #[cfg_attr(mls_build_async, maybe_async::must_be_async)] 172 impl GroupStateStorage for InMemoryGroupStateStorage { 173 type Error = Infallible; 174 max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error>175 async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> { 176 Ok(self 177 .lock() 178 .get(group_id) 179 .and_then(|group_data| group_data.epoch_data.back().map(|e| e.id))) 180 } 181 state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>182 async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> { 183 Ok(self 184 .lock() 185 .get(group_id) 186 .map(|data| data.state_data.clone())) 187 } 188 epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error>189 async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> { 190 Ok(self 191 .lock() 192 .get(group_id) 193 .and_then(|data| data.get_epoch(epoch_id).map(|ep| ep.data.clone()))) 194 } 195 write( &mut self, state: GroupState, epoch_inserts: Vec<EpochRecord>, epoch_updates: Vec<EpochRecord>, ) -> Result<(), Self::Error>196 async fn write( 197 &mut self, 198 state: GroupState, 199 epoch_inserts: Vec<EpochRecord>, 200 epoch_updates: Vec<EpochRecord>, 201 ) -> Result<(), Self::Error> { 202 let mut group_map = self.lock(); 203 204 let group_data = match group_map.entry(state.id) { 205 LargeMapEntry::Occupied(entry) => { 206 let data = entry.into_mut(); 207 data.state_data = state.data; 208 data 209 } 210 LargeMapEntry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)), 211 }; 212 213 epoch_inserts 214 .into_iter() 215 .for_each(|e| group_data.insert_epoch(e)); 216 217 epoch_updates 218 .into_iter() 219 .for_each(|e| group_data.update_epoch(e)); 220 221 group_data.trim_epochs(self.max_epoch_retention); 222 223 Ok(()) 224 } 225 } 226 227 #[cfg(all(test, feature = "prior_epoch"))] 228 mod tests { 229 use alloc::{format, vec, vec::Vec}; 230 use assert_matches::assert_matches; 231 232 use super::{InMemoryGroupData, InMemoryGroupStateStorage}; 233 use crate::{client::MlsError, group::test_utils::TEST_GROUP}; 234 235 use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage}; 236 237 impl InMemoryGroupStateStorage { test_data(&self) -> InMemoryGroupData238 fn test_data(&self) -> InMemoryGroupData { 239 self.lock().get(TEST_GROUP).unwrap().clone() 240 } 241 } 242 test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError>243 fn test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError> { 244 InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit) 245 } 246 test_epoch(epoch_id: u64) -> EpochRecord247 fn test_epoch(epoch_id: u64) -> EpochRecord { 248 EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec()) 249 } 250 test_snapshot(epoch_id: u64) -> GroupState251 fn test_snapshot(epoch_id: u64) -> GroupState { 252 GroupState { 253 id: TEST_GROUP.into(), 254 data: format!("snapshot {epoch_id}").as_bytes().to_vec(), 255 } 256 } 257 258 #[test] test_zero_max_retention()259 fn test_zero_max_retention() { 260 assert_matches!(test_storage(0), Err(MlsError::NonZeroRetentionRequired)) 261 } 262 263 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] existing_storage_can_have_larger_epoch_count()264 async fn existing_storage_can_have_larger_epoch_count() { 265 let mut storage = test_storage(2).unwrap(); 266 267 let epoch_inserts = vec![test_epoch(0), test_epoch(1)]; 268 269 storage 270 .write(test_snapshot(0), epoch_inserts, Vec::new()) 271 .await 272 .unwrap(); 273 274 assert_eq!(storage.test_data().epoch_data.len(), 2); 275 276 storage.max_epoch_retention = 4; 277 278 let epoch_inserts = vec![test_epoch(3), test_epoch(4)]; 279 280 storage 281 .write(test_snapshot(1), epoch_inserts, Vec::new()) 282 .await 283 .unwrap(); 284 285 assert_eq!(storage.test_data().epoch_data.len(), 4); 286 } 287 288 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] existing_storage_can_have_smaller_epoch_count()289 async fn existing_storage_can_have_smaller_epoch_count() { 290 let mut storage = test_storage(4).unwrap(); 291 292 let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(3), test_epoch(4)]; 293 294 storage 295 .write(test_snapshot(1), epoch_inserts, Vec::new()) 296 .await 297 .unwrap(); 298 299 assert_eq!(storage.test_data().epoch_data.len(), 4); 300 301 storage.max_epoch_retention = 2; 302 303 let epoch_inserts = vec![test_epoch(5)]; 304 305 storage 306 .write(test_snapshot(1), epoch_inserts, Vec::new()) 307 .await 308 .unwrap(); 309 310 assert_eq!(storage.test_data().epoch_data.len(), 2); 311 } 312 313 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] epoch_insert_over_limit()314 async fn epoch_insert_over_limit() { 315 test_epoch_insert_over_limit(false).await 316 } 317 318 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] epoch_insert_over_limit_with_update()319 async fn epoch_insert_over_limit_with_update() { 320 test_epoch_insert_over_limit(true).await 321 } 322 323 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_epoch_insert_over_limit(with_update: bool)324 async fn test_epoch_insert_over_limit(with_update: bool) { 325 let mut storage = test_storage(1).unwrap(); 326 327 let mut epoch_inserts = vec![test_epoch(0), test_epoch(1)]; 328 let updates = with_update 329 .then_some(vec![test_epoch(0)]) 330 .unwrap_or_default(); 331 let snapshot = test_snapshot(1); 332 333 storage 334 .write(snapshot.clone(), epoch_inserts.clone(), updates) 335 .await 336 .unwrap(); 337 338 let stored = storage.test_data(); 339 340 assert_eq!(stored.state_data, snapshot.data); 341 assert_eq!(stored.epoch_data.len(), 1); 342 343 let expected = epoch_inserts.pop().unwrap(); 344 assert_eq!(stored.epoch_data[0], expected); 345 } 346 } 347