// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use crate::client::MlsError; use crate::{group::PriorEpoch, key_package::KeyPackageRef}; use alloc::collections::VecDeque; use alloc::vec::Vec; use core::fmt::{self, Debug}; use mls_rs_codec::{MlsDecode, MlsEncode}; use mls_rs_core::group::{EpochRecord, GroupState}; use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage}; use super::snapshot::Snapshot; #[cfg(feature = "psk")] use crate::group::ResumptionPsk; #[cfg(feature = "psk")] use mls_rs_core::psk::PreSharedKey; /// A set of changes to apply to a GroupStateStorage implementation. These changes MUST /// be made in a single transaction to avoid creating invalid states. #[derive(Default, Clone, Debug)] struct EpochStorageCommit { pub(crate) inserts: VecDeque, pub(crate) updates: Vec, } #[derive(Clone)] pub(crate) struct GroupStateRepository where S: GroupStateStorage, K: KeyPackageStorage, { pending_commit: EpochStorageCommit, pending_key_package_removal: Option, group_id: Vec, storage: S, key_package_repo: K, } impl Debug for GroupStateRepository where S: GroupStateStorage + Debug, K: KeyPackageStorage + Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("GroupStateRepository") .field("pending_commit", &self.pending_commit) .field( "pending_key_package_removal", &self.pending_key_package_removal, ) .field( "group_id", &mls_rs_core::debug::pretty_group_id(&self.group_id), ) .field("storage", &self.storage) .field("key_package_repo", &self.key_package_repo) .finish() } } impl GroupStateRepository where S: GroupStateStorage, K: KeyPackageStorage, { pub fn new( group_id: Vec, storage: S, key_package_repo: K, // Set to `None` if restoring from snapshot; set to `Some` when joining a group. key_package_to_remove: Option, ) -> Result, MlsError> { Ok(GroupStateRepository { group_id, storage, pending_key_package_removal: key_package_to_remove, pending_commit: Default::default(), key_package_repo, }) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn find_max_id(&self) -> Result, MlsError> { if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) { Ok(Some(max)) } else { self.storage .max_epoch_id(&self.group_id) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error())) } } #[cfg(feature = "psk")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn resumption_secret( &self, psk_id: &ResumptionPsk, ) -> Result, MlsError> { // Search the local inserts cache if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) { if psk_id.psk_epoch >= min { return Ok(self .pending_commit .inserts .get((psk_id.psk_epoch - min) as usize) .map(|e| e.secrets.resumption_secret.clone())); } } // Search the local updates cache let maybe_pending = self.find_pending(psk_id.psk_epoch); if let Some(pending) = maybe_pending { return Ok(Some( self.pending_commit.updates[pending] .secrets .resumption_secret .clone(), )); } // Search the stored cache self.storage .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret)) .transpose() } #[cfg(feature = "private_message")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn get_epoch_mut( &mut self, epoch_id: u64, ) -> Result, MlsError> { // Search the local inserts cache if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) { if epoch_id >= min { return Ok(self .pending_commit .inserts .get_mut((epoch_id - min) as usize)); } } // Look in the cached updates map, and if not found look in disk storage // and insert into the updates map for future caching match self.find_pending(epoch_id) { Some(i) => self.pending_commit.updates.get_mut(i).map(Ok), None => self .storage .epoch(&self.group_id, epoch_id) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))? .and_then(|epoch| { PriorEpoch::mls_decode(&mut &*epoch) .map(|epoch| { self.pending_commit.updates.push(epoch); self.pending_commit.updates.last_mut() }) .transpose() }), } .transpose() .map_err(Into::into) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> { if epoch.group_id() != self.group_id { return Err(MlsError::GroupIdMismatch); } let epoch_id = epoch.epoch_id(); if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) { if epoch_id != expected_id { return Err(MlsError::InvalidEpoch); } } self.pending_commit.inserts.push_back(epoch); Ok(()) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> { let inserts = self .pending_commit .inserts .iter() .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?))) .collect::>()?; let updates = self .pending_commit .updates .iter() .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?))) .collect::>()?; let group_state = GroupState { data: group_snapshot.mls_encode_to_vec()?, id: group_snapshot.state.context.group_id, }; self.storage .write(group_state, inserts, updates) .await .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?; if let Some(ref key_package_ref) = self.pending_key_package_removal { self.key_package_repo .delete(key_package_ref) .await .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?; } self.pending_commit.inserts.clear(); self.pending_commit.updates.clear(); Ok(()) } #[cfg(any(feature = "psk", feature = "private_message"))] fn find_pending(&self, epoch_id: u64) -> Option { self.pending_commit .updates .iter() .position(|ep| ep.context.epoch == epoch_id) } } #[cfg(test)] mod tests { use alloc::vec; use mls_rs_codec::MlsEncode; use crate::{ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, group::{ epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret}, test_utils::{random_bytes, test_member, TEST_GROUP}, PskGroupId, ResumptionPSKUsage, }, storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage}, }; use super::*; fn test_group_state_repo( retention_limit: usize, ) -> GroupStateRepository { GroupStateRepository::new( TEST_GROUP.to_vec(), InMemoryGroupStateStorage::new() .with_max_epoch_retention(retention_limit) .unwrap(), InMemoryKeyPackageStorage::default(), None, ) .unwrap() } fn test_epoch(epoch_id: u64) -> PriorEpoch { get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn test_snapshot(epoch_id: u64) -> Snapshot { crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_epoch_inserts() { let mut test_repo = test_group_state_repo(1); let test_epoch = test_epoch(0); test_repo.insert(test_epoch.clone()).await.unwrap(); // Check the in-memory state assert_eq!( test_repo.pending_commit.inserts.back().unwrap(), &test_epoch ); assert!(test_repo.pending_commit.updates.is_empty()); #[cfg(feature = "std")] assert!(test_repo.storage.inner.lock().unwrap().is_empty()); #[cfg(not(feature = "std"))] assert!(test_repo.storage.inner.lock().is_empty()); let psk_id = ResumptionPsk { psk_epoch: 0, psk_group_id: PskGroupId(test_repo.group_id.clone()), usage: ResumptionPSKUsage::Application, }; // Make sure you can recall an epoch sitting as a pending insert let resumption = test_repo.resumption_secret(&psk_id).await.unwrap(); let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned(); assert_eq!( prior_epoch.clone().unwrap().secrets.resumption_secret, resumption.unwrap() ); assert_eq!(prior_epoch.unwrap(), test_epoch); // Write to the storage let snapshot = test_snapshot(test_epoch.epoch_id()).await; test_repo.write_to_storage(snapshot.clone()).await.unwrap(); // Make sure the memory cache cleared assert!(test_repo.pending_commit.inserts.is_empty()); assert!(test_repo.pending_commit.updates.is_empty()); // Make sure the storage was written #[cfg(feature = "std")] let storage = test_repo.storage.inner.lock().unwrap(); #[cfg(not(feature = "std"))] let storage = test_repo.storage.inner.lock(); assert_eq!(storage.len(), 1); let stored = storage.get(TEST_GROUP).unwrap(); assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap()); assert_eq!(stored.epoch_data.len(), 1); assert_eq!( stored.epoch_data.back().unwrap(), &EpochRecord::new( test_epoch.epoch_id(), test_epoch.mls_encode_to_vec().unwrap() ) ); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_updates() { let mut test_repo = test_group_state_repo(2); let test_epoch_0 = test_epoch(0); test_repo.insert(test_epoch_0.clone()).await.unwrap(); test_repo .write_to_storage(test_snapshot(0).await) .await .unwrap(); // Update the stored epoch let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap(); assert_eq!(to_update, &test_epoch_0); let new_sender_secret = random_bytes(32); to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret); let to_update = to_update.clone(); assert_eq!(test_repo.pending_commit.updates.len(), 1); assert!(test_repo.pending_commit.inserts.is_empty()); assert_eq!( test_repo.pending_commit.updates.first().unwrap(), &to_update ); // Make sure you can access an epoch pending update let psk_id = ResumptionPsk { psk_epoch: 0, psk_group_id: PskGroupId(test_repo.group_id.clone()), usage: ResumptionPSKUsage::Application, }; let owned = test_repo.resumption_secret(&psk_id).await.unwrap(); assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret)); // Write the update to storage let snapshot = test_snapshot(1).await; test_repo.write_to_storage(snapshot.clone()).await.unwrap(); assert!(test_repo.pending_commit.updates.is_empty()); assert!(test_repo.pending_commit.inserts.is_empty()); // Make sure the storage was written #[cfg(feature = "std")] let storage = test_repo.storage.inner.lock().unwrap(); #[cfg(not(feature = "std"))] let storage = test_repo.storage.inner.lock(); assert_eq!(storage.len(), 1); let stored = storage.get(TEST_GROUP).unwrap(); assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap()); assert_eq!(stored.epoch_data.len(), 1); assert_eq!( stored.epoch_data.back().unwrap(), &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap()) ); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_insert_and_update() { let mut test_repo = test_group_state_repo(2); let test_epoch_0 = test_epoch(0); test_repo.insert(test_epoch_0).await.unwrap(); test_repo .write_to_storage(test_snapshot(0).await) .await .unwrap(); // Update the stored epoch let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap(); let new_sender_secret = random_bytes(32); to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret); let to_update = to_update.clone(); // Insert another epoch let test_epoch_1 = test_epoch(1); test_repo.insert(test_epoch_1.clone()).await.unwrap(); test_repo .write_to_storage(test_snapshot(1).await) .await .unwrap(); assert!(test_repo.pending_commit.inserts.is_empty()); assert!(test_repo.pending_commit.updates.is_empty()); // Make sure the storage was written #[cfg(feature = "std")] let storage = test_repo.storage.inner.lock().unwrap(); #[cfg(not(feature = "std"))] let storage = test_repo.storage.inner.lock(); assert_eq!(storage.len(), 1); let stored = storage.get(TEST_GROUP).unwrap(); assert_eq!(stored.epoch_data.len(), 2); assert_eq!( stored.epoch_data.front().unwrap(), &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap()) ); assert_eq!( stored.epoch_data.back().unwrap(), &EpochRecord::new( test_epoch_1.epoch_id(), test_epoch_1.mls_encode_to_vec().unwrap() ) ); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_many_epochs_in_storage() { let epochs = (0..10).map(test_epoch).collect::>(); let mut test_repo = test_group_state_repo(10); for epoch in epochs.iter().cloned() { test_repo.insert(epoch).await.unwrap() } test_repo .write_to_storage(test_snapshot(9).await) .await .unwrap(); for mut epoch in epochs { let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap(); assert_eq!(res, Some(&mut epoch)); } } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_stored_groups_list() { let mut test_repo = test_group_state_repo(2); let test_epoch_0 = test_epoch(0); test_repo.insert(test_epoch_0.clone()).await.unwrap(); test_repo .write_to_storage(test_snapshot(0).await) .await .unwrap(); assert_eq!( test_repo.storage.stored_groups(), vec![test_epoch_0.context.group_id] ) } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn reducing_retention_limit_takes_effect_on_epoch_access() { let mut repo = test_group_state_repo(1); repo.insert(test_epoch(0)).await.unwrap(); repo.insert(test_epoch(1)).await.unwrap(); repo.write_to_storage(test_snapshot(0).await).await.unwrap(); let mut repo = GroupStateRepository { storage: repo.storage, ..test_group_state_repo(1) }; let res = repo.get_epoch_mut(0).await.unwrap(); assert!(res.is_none()); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn in_memory_storage_obeys_retention_limit_after_saving() { let mut repo = test_group_state_repo(1); repo.insert(test_epoch(0)).await.unwrap(); repo.write_to_storage(test_snapshot(0).await).await.unwrap(); repo.insert(test_epoch(1)).await.unwrap(); repo.write_to_storage(test_snapshot(1).await).await.unwrap(); #[cfg(feature = "std")] let lock = repo.storage.inner.lock().unwrap(); #[cfg(not(feature = "std"))] let lock = repo.storage.inner.lock(); assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1); } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn used_key_package_is_deleted() { let key_package_repo = InMemoryKeyPackageStorage::default(); let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member") .await .0; let (id, data) = key_package.to_storage().unwrap(); key_package_repo.insert(id, data); let mut repo = GroupStateRepository::new( TEST_GROUP.to_vec(), InMemoryGroupStateStorage::new(), key_package_repo, Some(key_package.reference.clone()), ) .unwrap(); repo.key_package_repo.get(&key_package.reference).unwrap(); repo.write_to_storage(test_snapshot(4).await).await.unwrap(); assert!(repo.key_package_repo.get(&key_package.reference).is_none()); } }