• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::{
6     client::MlsError,
7     client_config::ClientConfig,
8     group::{
9         cipher_suite_provider, epoch::EpochSecrets, key_schedule::KeySchedule,
10         state_repo::GroupStateRepository, CommitGeneration, ConfirmationTag, Group, GroupContext,
11         GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
12     },
13     tree_kem::TreeKemPrivate,
14 };
15 
16 #[cfg(feature = "by_ref_proposal")]
17 use crate::{
18     crypto::{HpkePublicKey, HpkeSecretKey},
19     group::{
20         message_hash::MessageHash,
21         proposal_cache::{CachedProposal, ProposalCache},
22         ProposalMessageDescription, ProposalRef,
23     },
24     map::SmallMap,
25 };
26 
27 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
28 use mls_rs_core::crypto::SignatureSecretKey;
29 #[cfg(feature = "tree_index")]
30 use mls_rs_core::identity::IdentityProvider;
31 
/// Serializable snapshot of a `Group`.
///
/// Produced by `Group::snapshot` and turned back into a group by `Group::from_snapshot`.
///
/// NOTE(review): the derived `MlsEncode`/`MlsDecode` impls presumably encode fields in
/// declaration order, so reordering or inserting fields likely changes the stored format.
#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Snapshot {
    // Snapshot format version: `Group::snapshot` always writes 1, and
    // `Group::from_snapshot` does not currently check it.
    version: u16,
    // Group context, public tree, proposals and transcript data (see `RawGroupState`).
    pub(crate) state: RawGroupState,
    // Copied from the group's `private_tree`.
    private_tree: TreeKemPrivate,
    epoch_secrets: EpochSecrets,
    key_schedule: KeySchedule,
    // Pending updates keyed by HPKE public key, mapping to
    // (HPKE secret key, optional signing key).
    #[cfg(feature = "by_ref_proposal")]
    pending_updates: SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
    // Pending commit, if any; restored as-is by `Group::from_snapshot`.
    pending_commit: Option<CommitGeneration>,
    // Signing key handed back to the restored group.
    signer: SignatureSecretKey,
}
45 
/// Persistable form of a `GroupState`.
///
/// Created by `RawGroupState::export` and converted back with `RawGroupState::import`.
///
/// NOTE(review): as with `Snapshot`, the derived codec presumably depends on field
/// declaration order; do not reorder fields without considering stored snapshots.
#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
    pub(crate) context: GroupContext,
    // Contents of the group's proposal cache: cached proposals by reference.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) proposals: SmallMap<ProposalRef, CachedProposal>,
    // Contents of the group's proposal cache: this member's own proposal messages by hash.
    #[cfg(feature = "by_ref_proposal")]
    pub(crate) own_proposals: SmallMap<MessageHash, ProposalMessageDescription>,
    // Without the `tree_index` feature, `export` stores only the tree nodes.
    pub(crate) public_tree: TreeKemPublic,
    pub(crate) interim_transcript_hash: InterimTranscriptHash,
    pub(crate) pending_reinit: Option<ReInitProposal>,
    pub(crate) confirmation_tag: ConfirmationTag,
}
59 
60 impl RawGroupState {
export(state: &GroupState) -> Self61     pub(crate) fn export(state: &GroupState) -> Self {
62         #[cfg(feature = "tree_index")]
63         let public_tree = state.public_tree.clone();
64 
65         #[cfg(not(feature = "tree_index"))]
66         let public_tree = {
67             let mut tree = TreeKemPublic::new();
68             tree.nodes = state.public_tree.nodes.clone();
69             tree
70         };
71 
72         Self {
73             context: state.context.clone(),
74             #[cfg(feature = "by_ref_proposal")]
75             proposals: state.proposals.proposals.clone(),
76             #[cfg(feature = "by_ref_proposal")]
77             own_proposals: state.proposals.own_proposals.clone(),
78             public_tree,
79             interim_transcript_hash: state.interim_transcript_hash.clone(),
80             pending_reinit: state.pending_reinit.clone(),
81             confirmation_tag: state.confirmation_tag.clone(),
82         }
83     }
84 
85     #[cfg(feature = "tree_index")]
86     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> where C: IdentityProvider,87     pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
88     where
89         C: IdentityProvider,
90     {
91         let context = self.context;
92 
93         #[cfg(feature = "by_ref_proposal")]
94         let proposals = ProposalCache::import(
95             context.protocol_version,
96             context.group_id.clone(),
97             self.proposals,
98             self.own_proposals.clone(),
99         );
100 
101         let mut public_tree = self.public_tree;
102 
103         public_tree
104             .initialize_index_if_necessary(identity_provider, &context.extensions)
105             .await?;
106 
107         Ok(GroupState {
108             #[cfg(feature = "by_ref_proposal")]
109             proposals,
110             context,
111             public_tree,
112             interim_transcript_hash: self.interim_transcript_hash,
113             pending_reinit: self.pending_reinit,
114             confirmation_tag: self.confirmation_tag,
115         })
116     }
117 
118     #[cfg(not(feature = "tree_index"))]
119     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
import(self) -> Result<GroupState, MlsError>120     pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
121         let context = self.context;
122 
123         #[cfg(feature = "by_ref_proposal")]
124         let proposals = ProposalCache::import(
125             context.protocol_version,
126             context.group_id.clone(),
127             self.proposals,
128             self.own_proposals.clone(),
129         );
130 
131         Ok(GroupState {
132             #[cfg(feature = "by_ref_proposal")]
133             proposals,
134             context,
135             public_tree: self.public_tree,
136             interim_transcript_hash: self.interim_transcript_hash,
137             pending_reinit: self.pending_reinit,
138             confirmation_tag: self.confirmation_tag,
139         })
140     }
141 }
142 
143 impl<C> Group<C>
144 where
145     C: ClientConfig + Clone,
146 {
147     /// Write the current state of the group to the
148     /// [`GroupStorageProvider`](crate::GroupStateStorage)
149     /// that is currently in use by the group.
150     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
write_to_storage(&mut self) -> Result<(), MlsError>151     pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
152         self.state_repo.write_to_storage(self.snapshot()).await
153     }
154 
snapshot(&self) -> Snapshot155     pub(crate) fn snapshot(&self) -> Snapshot {
156         Snapshot {
157             state: RawGroupState::export(&self.state),
158             private_tree: self.private_tree.clone(),
159             key_schedule: self.key_schedule.clone(),
160             #[cfg(feature = "by_ref_proposal")]
161             pending_updates: self.pending_updates.clone(),
162             pending_commit: self.pending_commit.clone(),
163             epoch_secrets: self.epoch_secrets.clone(),
164             version: 1,
165             signer: self.signer.clone(),
166         }
167     }
168 
169     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError>170     pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
171         let cipher_suite_provider = cipher_suite_provider(
172             config.crypto_provider(),
173             snapshot.state.context.cipher_suite,
174         )?;
175 
176         #[cfg(feature = "tree_index")]
177         let identity_provider = config.identity_provider();
178 
179         let state_repo = GroupStateRepository::new(
180             #[cfg(feature = "prior_epoch")]
181             snapshot.state.context.group_id.clone(),
182             config.group_state_storage(),
183             config.key_package_repo(),
184             None,
185         )?;
186 
187         Ok(Group {
188             config,
189             state: snapshot
190                 .state
191                 .import(
192                     #[cfg(feature = "tree_index")]
193                     &identity_provider,
194                 )
195                 .await?,
196             private_tree: snapshot.private_tree,
197             key_schedule: snapshot.key_schedule,
198             #[cfg(feature = "by_ref_proposal")]
199             pending_updates: snapshot.pending_updates,
200             pending_commit: snapshot.pending_commit,
201             #[cfg(test)]
202             commit_modifiers: Default::default(),
203             epoch_secrets: snapshot.epoch_secrets,
204             state_repo,
205             cipher_suite_provider,
206             #[cfg(feature = "psk")]
207             previous_psk: None,
208             signer: snapshot.signer,
209         })
210     }
211 }
212 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
            key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
            transcript_hash::InterimTranscriptHash,
        },
        tree_kem::{node::LeafIndex, TreeKemPrivate},
    };

    use super::{RawGroupState, Snapshot};

    /// Builds a `Snapshot` fixture for `cipher_suite` whose group context is at `epoch_id`.
    ///
    /// All other fields are minimal:
    /// - default public tree and proposal maps;
    /// - empty interim transcript hash;
    /// - `ConfirmationTag::empty` tag;
    /// - private tree for leaf 0;
    /// - no pending commit or reinit;
    /// - empty signer.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
        let context = get_test_group_context(epoch_id, cipher_suite).await;

        let suite_provider = test_cipher_suite_provider(cipher_suite);
        let confirmation_tag = ConfirmationTag::empty(&suite_provider).await;

        let state = RawGroupState {
            context,
            #[cfg(feature = "by_ref_proposal")]
            proposals: Default::default(),
            #[cfg(feature = "by_ref_proposal")]
            own_proposals: Default::default(),
            public_tree: Default::default(),
            interim_transcript_hash: InterimTranscriptHash::from(vec![]),
            pending_reinit: None,
            confirmation_tag,
        };

        Snapshot {
            version: 1,
            state,
            private_tree: TreeKemPrivate::new(LeafIndex(0)),
            epoch_secrets: get_test_epoch_secrets(cipher_suite),
            key_schedule: get_test_key_schedule(cipher_suite),
            #[cfg(feature = "by_ref_proposal")]
            pending_updates: Default::default(),
            pending_commit: None,
            signer: vec![].into(),
        }
    }
}
256 
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            test_utils::{test_group, TestGroup},
            Group,
        },
    };

    /// Snapshots `fixture.group`, restores a new group from that snapshot with a clone of
    /// the same config, and asserts that the restored state matches the original.
    ///
    /// With `tree_index`, it also checks that the public tree internals match.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn snapshot_restore(fixture: TestGroup) {
        let original = &fixture.group;
        let snapshot = original.snapshot();
        let config = original.config.clone();

        let restored = Group::from_snapshot(config, snapshot).await.unwrap();

        assert!(Group::equal_group_state(original, &restored));

        #[cfg(feature = "tree_index")]
        assert!(restored
            .state
            .public_tree
            .equal_internals(&original.state.public_tree));
    }

    // The next three test names mention JSON, but they only exercise the in-memory
    // `snapshot` / `from_snapshot` round trip. The `serde` test below is the one that
    // actually goes through JSON.

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
        let mut fixture = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        // Create a commit without applying it, so the snapshot has a pending commit.
        fixture.group.commit(vec![]).await.unwrap();

        snapshot_restore(fixture).await
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
        let mut fixture = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        // Building the update proposal records it in the group's pending updates...
        let update = fixture.update_proposal().await;

        // ...and sending it as a message puts it in the internal proposal cache.
        let _ = fixture.group.proposal_message(update, vec![]).await;

        snapshot_restore(fixture).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_can_be_serialized_to_json_with_internals() {
        snapshot_restore(test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await).await
    }

    #[cfg(feature = "serde")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn serde() {
        let original = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;

        let encoded = serde_json::to_string_pretty(&original).unwrap();
        let decoded: super::Snapshot = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }
}
324