• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::collections::VecDeque;
6 
7 #[cfg(target_has_atomic = "ptr")]
8 use alloc::sync::Arc;
9 
10 #[cfg(mls_build_async)]
11 use alloc::boxed::Box;
12 use alloc::vec::Vec;
13 use core::{
14     convert::Infallible,
15     fmt::{self, Debug},
16 };
17 use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
18 #[cfg(not(target_has_atomic = "ptr"))]
19 use portable_atomic_util::Arc;
20 
21 use crate::{
22     client::MlsError,
23     map::{LargeMap, LargeMapEntry},
24 };
25 
26 #[cfg(feature = "std")]
27 use std::sync::{Mutex, MutexGuard};
28 
29 #[cfg(not(feature = "std"))]
30 use spin::{Mutex, MutexGuard};
31 
/// Number of epoch records kept per group by a freshly constructed
/// `InMemoryGroupStateStorage` unless overridden with `with_max_epoch_retention`.
pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3;
33 
34 #[derive(Clone)]
35 pub(crate) struct InMemoryGroupData {
36     pub(crate) state_data: Vec<u8>,
37     pub(crate) epoch_data: VecDeque<EpochRecord>,
38 }
39 
40 impl Debug for InMemoryGroupData {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result41     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42         f.debug_struct("InMemoryGroupData")
43             .field(
44                 "state_data",
45                 &mls_rs_core::debug::pretty_bytes(&self.state_data),
46             )
47             .field("epoch_data", &self.epoch_data)
48             .finish()
49     }
50 }
51 
52 impl InMemoryGroupData {
new(state_data: Vec<u8>) -> InMemoryGroupData53     pub fn new(state_data: Vec<u8>) -> InMemoryGroupData {
54         InMemoryGroupData {
55             state_data,
56             epoch_data: Default::default(),
57         }
58     }
59 
get_epoch_data_index(&self, epoch_id: u64) -> Option<u64>60     fn get_epoch_data_index(&self, epoch_id: u64) -> Option<u64> {
61         self.epoch_data
62             .front()
63             .and_then(|e| epoch_id.checked_sub(e.id))
64     }
65 
get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord>66     pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> {
67         self.get_epoch_data_index(epoch_id)
68             .and_then(|i| self.epoch_data.get(i as usize))
69     }
70 
get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord>71     pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> {
72         self.get_epoch_data_index(epoch_id)
73             .and_then(|i| self.epoch_data.get_mut(i as usize))
74     }
75 
insert_epoch(&mut self, epoch: EpochRecord)76     pub fn insert_epoch(&mut self, epoch: EpochRecord) {
77         self.epoch_data.push_back(epoch)
78     }
79 
80     // This function does not fail if an update can't be made. If the epoch
81     // is not in the store, then it can no longer be accessed by future
82     // get_epoch calls and is no longer relevant.
update_epoch(&mut self, epoch: EpochRecord)83     pub fn update_epoch(&mut self, epoch: EpochRecord) {
84         if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) {
85             *existing_epoch = epoch
86         }
87     }
88 
trim_epochs(&mut self, max_epoch_retention: usize)89     pub fn trim_epochs(&mut self, max_epoch_retention: usize) {
90         while self.epoch_data.len() > max_epoch_retention {
91             self.epoch_data.pop_front();
92         }
93     }
94 }
95 
96 #[derive(Clone)]
97 /// In memory group state storage backed by a HashMap.
98 ///
99 /// All clones of an instance of this type share the same underlying HashMap.
100 pub struct InMemoryGroupStateStorage {
101     pub(crate) inner: Arc<Mutex<LargeMap<Vec<u8>, InMemoryGroupData>>>,
102     pub(crate) max_epoch_retention: usize,
103 }
104 
105 impl Debug for InMemoryGroupStateStorage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result106     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107         f.debug_struct("InMemoryGroupStateStorage")
108             .field(
109                 "inner",
110                 &mls_rs_core::debug::pretty_with(|f| {
111                     f.debug_map()
112                         .entries(
113                             self.lock()
114                                 .iter()
115                                 .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
116                         )
117                         .finish()
118                 }),
119             )
120             .field("max_epoch_retention", &self.max_epoch_retention)
121             .finish()
122     }
123 }
124 
125 impl InMemoryGroupStateStorage {
126     /// Create an empty group state storage.
new() -> Self127     pub fn new() -> Self {
128         Self {
129             inner: Default::default(),
130             max_epoch_retention: DEFAULT_EPOCH_RETENTION_LIMIT,
131         }
132     }
133 
with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError>134     pub fn with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError> {
135         (max_epoch_retention > 0)
136             .then_some(())
137             .ok_or(MlsError::NonZeroRetentionRequired)?;
138 
139         Ok(Self {
140             inner: self.inner,
141             max_epoch_retention,
142         })
143     }
144 
145     /// Get the set of unique group ids that have data stored.
stored_groups(&self) -> Vec<Vec<u8>>146     pub fn stored_groups(&self) -> Vec<Vec<u8>> {
147         self.lock().keys().cloned().collect()
148     }
149 
150     /// Delete all data corresponding to `group_id`.
delete_group(&self, group_id: &[u8])151     pub fn delete_group(&self, group_id: &[u8]) {
152         self.lock().remove(group_id);
153     }
154 
lock(&self) -> MutexGuard<'_, LargeMap<Vec<u8>, InMemoryGroupData>>155     fn lock(&self) -> MutexGuard<'_, LargeMap<Vec<u8>, InMemoryGroupData>> {
156         #[cfg(feature = "std")]
157         return self.inner.lock().unwrap();
158 
159         #[cfg(not(feature = "std"))]
160         return self.inner.lock();
161     }
162 }
163 
164 impl Default for InMemoryGroupStateStorage {
default() -> Self165     fn default() -> Self {
166         Self::new()
167     }
168 }
169 
170 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
171 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
172 impl GroupStateStorage for InMemoryGroupStateStorage {
173     type Error = Infallible;
174 
max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error>175     async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
176         Ok(self
177             .lock()
178             .get(group_id)
179             .and_then(|group_data| group_data.epoch_data.back().map(|e| e.id)))
180     }
181 
state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>182     async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
183         Ok(self
184             .lock()
185             .get(group_id)
186             .map(|data| data.state_data.clone()))
187     }
188 
epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error>189     async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
190         Ok(self
191             .lock()
192             .get(group_id)
193             .and_then(|data| data.get_epoch(epoch_id).map(|ep| ep.data.clone())))
194     }
195 
write( &mut self, state: GroupState, epoch_inserts: Vec<EpochRecord>, epoch_updates: Vec<EpochRecord>, ) -> Result<(), Self::Error>196     async fn write(
197         &mut self,
198         state: GroupState,
199         epoch_inserts: Vec<EpochRecord>,
200         epoch_updates: Vec<EpochRecord>,
201     ) -> Result<(), Self::Error> {
202         let mut group_map = self.lock();
203 
204         let group_data = match group_map.entry(state.id) {
205             LargeMapEntry::Occupied(entry) => {
206                 let data = entry.into_mut();
207                 data.state_data = state.data;
208                 data
209             }
210             LargeMapEntry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)),
211         };
212 
213         epoch_inserts
214             .into_iter()
215             .for_each(|e| group_data.insert_epoch(e));
216 
217         epoch_updates
218             .into_iter()
219             .for_each(|e| group_data.update_epoch(e));
220 
221         group_data.trim_epochs(self.max_epoch_retention);
222 
223         Ok(())
224     }
225 }
226 
#[cfg(all(test, feature = "prior_epoch"))]
mod tests {
    use alloc::{format, vec, vec::Vec};
    use assert_matches::assert_matches;

    use super::{InMemoryGroupData, InMemoryGroupStateStorage};
    use crate::{client::MlsError, group::test_utils::TEST_GROUP};

    use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};

    impl InMemoryGroupStateStorage {
        /// Clone of everything stored for `TEST_GROUP`; panics if it was never written.
        fn test_data(&self) -> InMemoryGroupData {
            self.lock().get(TEST_GROUP).unwrap().clone()
        }

        /// Ids of the epochs currently retained for `TEST_GROUP`, oldest first.
        fn test_epoch_ids(&self) -> Vec<u64> {
            self.test_data().epoch_data.iter().map(|e| e.id).collect()
        }
    }

    fn test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError> {
        InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit)
    }

    fn test_epoch(epoch_id: u64) -> EpochRecord {
        EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec())
    }

    fn test_snapshot(epoch_id: u64) -> GroupState {
        GroupState {
            id: TEST_GROUP.into(),
            data: format!("snapshot {epoch_id}").as_bytes().to_vec(),
        }
    }

    #[test]
    fn test_zero_max_retention() {
        assert_matches!(test_storage(0), Err(MlsError::NonZeroRetentionRequired))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn existing_storage_can_have_larger_epoch_count() {
        let mut storage = test_storage(2).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1)];

        storage
            .write(test_snapshot(1), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 2);

        storage.max_epoch_retention = 4;

        // Epoch ids are consecutive, matching how a group appends them.
        let epoch_inserts = vec![test_epoch(2), test_epoch(3)];

        storage
            .write(test_snapshot(3), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_epoch_ids(), vec![0u64, 1, 2, 3]);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn existing_storage_can_have_smaller_epoch_count() {
        let mut storage = test_storage(4).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(2), test_epoch(3)];

        storage
            .write(test_snapshot(3), epoch_inserts, Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.test_data().epoch_data.len(), 4);

        storage.max_epoch_retention = 2;

        let epoch_inserts = vec![test_epoch(4)];

        storage
            .write(test_snapshot(4), epoch_inserts, Vec::new())
            .await
            .unwrap();

        // Only the newest epochs survive trimming.
        assert_eq!(storage.test_epoch_ids(), vec![3u64, 4]);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_lookup_respects_retention_window() {
        let mut storage = test_storage(2).unwrap();

        let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(2)];

        storage
            .write(test_snapshot(2), epoch_inserts, Vec::new())
            .await
            .unwrap();

        // Epoch 0 was trimmed; 1 and 2 are retained; 3 was never written.
        assert_eq!(storage.epoch(TEST_GROUP, 0).await.unwrap(), None);
        assert_eq!(
            storage.epoch(TEST_GROUP, 1).await.unwrap(),
            Some(test_epoch(1).data)
        );
        assert_eq!(
            storage.epoch(TEST_GROUP, 2).await.unwrap(),
            Some(test_epoch(2).data)
        );
        assert_eq!(storage.epoch(TEST_GROUP, 3).await.unwrap(), None);
        assert_eq!(storage.max_epoch_id(TEST_GROUP).await.unwrap(), Some(2));
        assert_eq!(
            storage.state(TEST_GROUP).await.unwrap(),
            Some(test_snapshot(2).data)
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_update_replaces_only_matching_record() {
        let mut storage = test_storage(2).unwrap();

        storage
            .write(test_snapshot(1), vec![test_epoch(0), test_epoch(1)], Vec::new())
            .await
            .unwrap();

        let updated = EpochRecord::new(0, b"updated epoch 0".to_vec());

        storage
            .write(test_snapshot(1), Vec::new(), vec![updated.clone()])
            .await
            .unwrap();

        let stored = storage.test_data();
        assert_eq!(stored.epoch_data[0], updated);
        assert_eq!(stored.epoch_data[1], test_epoch(1));

        // Updates for epochs that are not retained are silently ignored.
        storage
            .write(
                test_snapshot(1),
                Vec::new(),
                vec![EpochRecord::new(7, b"unknown".to_vec())],
            )
            .await
            .unwrap();

        assert_eq!(storage.test_epoch_ids(), vec![0u64, 1]);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn delete_group_removes_stored_data() {
        let mut storage = test_storage(1).unwrap();

        storage
            .write(test_snapshot(0), vec![test_epoch(0)], Vec::new())
            .await
            .unwrap();

        assert_eq!(storage.stored_groups(), vec![TEST_GROUP.to_vec()]);

        storage.delete_group(TEST_GROUP);

        assert!(storage.stored_groups().is_empty());
        assert_eq!(storage.state(TEST_GROUP).await.unwrap(), None);
        assert_eq!(storage.max_epoch_id(TEST_GROUP).await.unwrap(), None);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_insert_over_limit() {
        test_epoch_insert_over_limit(false).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn epoch_insert_over_limit_with_update() {
        test_epoch_insert_over_limit(true).await
    }

    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_epoch_insert_over_limit(with_update: bool) {
        let mut storage = test_storage(1).unwrap();

        let mut epoch_inserts = vec![test_epoch(0), test_epoch(1)];
        let updates = with_update
            .then_some(vec![test_epoch(0)])
            .unwrap_or_default();
        let snapshot = test_snapshot(1);

        storage
            .write(snapshot.clone(), epoch_inserts.clone(), updates)
            .await
            .unwrap();

        let stored = storage.test_data();

        assert_eq!(stored.state_data, snapshot.data);
        assert_eq!(stored.epoch_data.len(), 1);

        let expected = epoch_inserts.pop().unwrap();
        assert_eq!(stored.epoch_data[0], expected);
    }
}
347