1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::{ 6 client::MlsError, 7 client_config::ClientConfig, 8 group::{ 9 cipher_suite_provider, epoch::EpochSecrets, key_schedule::KeySchedule, 10 state_repo::GroupStateRepository, CommitGeneration, ConfirmationTag, Group, GroupContext, 11 GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic, 12 }, 13 tree_kem::TreeKemPrivate, 14 }; 15 16 #[cfg(feature = "by_ref_proposal")] 17 use crate::{ 18 crypto::{HpkePublicKey, HpkeSecretKey}, 19 group::{ 20 message_hash::MessageHash, 21 proposal_cache::{CachedProposal, ProposalCache}, 22 ProposalMessageDescription, ProposalRef, 23 }, 24 map::SmallMap, 25 }; 26 27 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 28 use mls_rs_core::crypto::SignatureSecretKey; 29 #[cfg(feature = "tree_index")] 30 use mls_rs_core::identity::IdentityProvider; 31 32 #[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)] 33 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 34 pub(crate) struct Snapshot { 35 version: u16, 36 pub(crate) state: RawGroupState, 37 private_tree: TreeKemPrivate, 38 epoch_secrets: EpochSecrets, 39 key_schedule: KeySchedule, 40 #[cfg(feature = "by_ref_proposal")] 41 pending_updates: SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, 42 pending_commit: Option<CommitGeneration>, 43 signer: SignatureSecretKey, 44 } 45 46 #[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)] 47 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 48 pub(crate) struct RawGroupState { 49 pub(crate) context: GroupContext, 50 #[cfg(feature = "by_ref_proposal")] 51 pub(crate) proposals: SmallMap<ProposalRef, CachedProposal>, 52 #[cfg(feature = "by_ref_proposal")] 53 pub(crate) own_proposals: SmallMap<MessageHash, ProposalMessageDescription>, 54 pub(crate) public_tree: TreeKemPublic, 55 pub(crate) interim_transcript_hash: InterimTranscriptHash, 56 pub(crate) pending_reinit: Option<ReInitProposal>, 57 pub(crate) confirmation_tag: ConfirmationTag, 58 } 59 60 impl RawGroupState { export(state: &GroupState) -> Self61 pub(crate) fn export(state: &GroupState) -> Self { 62 #[cfg(feature = "tree_index")] 63 let public_tree = state.public_tree.clone(); 64 65 #[cfg(not(feature = "tree_index"))] 66 let public_tree = { 67 let mut tree = TreeKemPublic::new(); 68 tree.nodes = state.public_tree.nodes.clone(); 69 tree 70 }; 71 72 Self { 73 context: state.context.clone(), 74 #[cfg(feature = "by_ref_proposal")] 75 proposals: state.proposals.proposals.clone(), 76 #[cfg(feature = "by_ref_proposal")] 77 own_proposals: state.proposals.own_proposals.clone(), 78 public_tree, 79 interim_transcript_hash: state.interim_transcript_hash.clone(), 80 pending_reinit: state.pending_reinit.clone(), 81 confirmation_tag: state.confirmation_tag.clone(), 82 } 83 } 84 85 #[cfg(feature = "tree_index")] 86 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> where C: IdentityProvider,87 pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> 88 where 89 C: IdentityProvider, 90 { 91 let context = self.context; 92 93 #[cfg(feature = "by_ref_proposal")] 94 let proposals = ProposalCache::import( 95 context.protocol_version, 96 context.group_id.clone(), 97 self.proposals, 98 self.own_proposals.clone(), 99 ); 100 101 let mut public_tree = self.public_tree; 102 103 public_tree 104 .initialize_index_if_necessary(identity_provider, &context.extensions) 105 .await?; 106 107 Ok(GroupState { 108 #[cfg(feature = "by_ref_proposal")] 109 proposals, 110 context, 111 public_tree, 112 interim_transcript_hash: self.interim_transcript_hash, 113 pending_reinit: self.pending_reinit, 114 confirmation_tag: self.confirmation_tag, 115 }) 116 } 117 118 #[cfg(not(feature = "tree_index"))] 119 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] import(self) -> Result<GroupState, MlsError>120 pub(crate) async fn import(self) -> Result<GroupState, MlsError> { 121 let context = self.context; 122 123 #[cfg(feature = "by_ref_proposal")] 124 let proposals = ProposalCache::import( 125 context.protocol_version, 126 context.group_id.clone(), 127 self.proposals, 128 self.own_proposals.clone(), 129 ); 130 131 Ok(GroupState { 132 #[cfg(feature = "by_ref_proposal")] 133 proposals, 134 context, 135 public_tree: self.public_tree, 136 interim_transcript_hash: self.interim_transcript_hash, 137 pending_reinit: self.pending_reinit, 138 confirmation_tag: self.confirmation_tag, 139 }) 140 } 141 } 142 143 impl<C> Group<C> 144 where 145 C: ClientConfig + Clone, 146 { 147 /// Write the current state of the group to the 148 /// [`GroupStorageProvider`](crate::GroupStateStorage) 149 /// that is currently in use by the group. 150 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] write_to_storage(&mut self) -> Result<(), MlsError>151 pub async fn write_to_storage(&mut self) -> Result<(), MlsError> { 152 self.state_repo.write_to_storage(self.snapshot()).await 153 } 154 snapshot(&self) -> Snapshot155 pub(crate) fn snapshot(&self) -> Snapshot { 156 Snapshot { 157 state: RawGroupState::export(&self.state), 158 private_tree: self.private_tree.clone(), 159 key_schedule: self.key_schedule.clone(), 160 #[cfg(feature = "by_ref_proposal")] 161 pending_updates: self.pending_updates.clone(), 162 pending_commit: self.pending_commit.clone(), 163 epoch_secrets: self.epoch_secrets.clone(), 164 version: 1, 165 signer: self.signer.clone(), 166 } 167 } 168 169 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError>170 pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> { 171 let cipher_suite_provider = cipher_suite_provider( 172 config.crypto_provider(), 173 snapshot.state.context.cipher_suite, 174 )?; 175 176 #[cfg(feature = "tree_index")] 177 let identity_provider = config.identity_provider(); 178 179 let state_repo = GroupStateRepository::new( 180 #[cfg(feature = "prior_epoch")] 181 snapshot.state.context.group_id.clone(), 182 config.group_state_storage(), 183 config.key_package_repo(), 184 None, 185 )?; 186 187 Ok(Group { 188 config, 189 state: snapshot 190 .state 191 .import( 192 #[cfg(feature = "tree_index")] 193 &identity_provider, 194 ) 195 .await?, 196 private_tree: snapshot.private_tree, 197 key_schedule: snapshot.key_schedule, 198 #[cfg(feature = "by_ref_proposal")] 199 pending_updates: snapshot.pending_updates, 200 pending_commit: snapshot.pending_commit, 201 #[cfg(test)] 202 commit_modifiers: Default::default(), 203 epoch_secrets: snapshot.epoch_secrets, 204 state_repo, 205 cipher_suite_provider, 206 #[cfg(feature = "psk")] 207 previous_psk: None, 208 signer: snapshot.signer, 209 }) 210 } 211 } 212 213 #[cfg(test)] 214 pub(crate) mod test_utils { 215 use alloc::vec; 216 217 use crate::{ 218 cipher_suite::CipherSuite, 219 crypto::test_utils::test_cipher_suite_provider, 220 group::{ 221 confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets, 222 key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context, 223 transcript_hash::InterimTranscriptHash, 224 }, 225 tree_kem::{node::LeafIndex, TreeKemPrivate}, 226 }; 227 228 use super::{RawGroupState, Snapshot}; 229 230 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot231 pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot { 232 Snapshot { 233 state: RawGroupState { 234 context: get_test_group_context(epoch_id, cipher_suite).await, 235 #[cfg(feature = "by_ref_proposal")] 236 proposals: Default::default(), 237 #[cfg(feature = "by_ref_proposal")] 238 own_proposals: Default::default(), 239 public_tree: Default::default(), 240 interim_transcript_hash: InterimTranscriptHash::from(vec![]), 241 pending_reinit: None, 242 confirmation_tag: ConfirmationTag::empty(&test_cipher_suite_provider(cipher_suite)) 243 .await, 244 }, 245 private_tree: TreeKemPrivate::new(LeafIndex(0)), 246 epoch_secrets: get_test_epoch_secrets(cipher_suite), 247 key_schedule: get_test_key_schedule(cipher_suite), 248 #[cfg(feature = "by_ref_proposal")] 249 pending_updates: Default::default(), 250 pending_commit: None, 251 version: 1, 252 signer: vec![].into(), 253 } 254 } 255 } 256 257 #[cfg(test)] 258 mod tests { 259 use alloc::vec; 260 261 use crate::{ 262 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, 263 group::{ 264 test_utils::{test_group, TestGroup}, 265 Group, 266 }, 267 }; 268 269 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] snapshot_restore(group: TestGroup)270 async fn snapshot_restore(group: TestGroup) { 271 let snapshot = group.group.snapshot(); 272 273 let group_restored = Group::from_snapshot(group.group.config.clone(), snapshot) 274 .await 275 .unwrap(); 276 277 assert!(Group::equal_group_state(&group.group, &group_restored)); 278 279 #[cfg(feature = "tree_index")] 280 assert!(group_restored 281 .state 282 .public_tree 283 .equal_internals(&group.group.state.public_tree)) 284 } 285 286 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] snapshot_with_pending_commit_can_be_serialized_to_json()287 async fn snapshot_with_pending_commit_can_be_serialized_to_json() { 288 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; 289 group.group.commit(vec![]).await.unwrap(); 290 291 snapshot_restore(group).await 292 } 293 294 #[cfg(feature = "by_ref_proposal")] 295 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] snapshot_with_pending_updates_can_be_serialized_to_json()296 async fn snapshot_with_pending_updates_can_be_serialized_to_json() { 297 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; 298 299 // Creating the update proposal will add it to pending updates 300 let update_proposal = group.update_proposal().await; 301 302 // This will insert the proposal into the internal proposal cache 303 let _ = group.group.proposal_message(update_proposal, vec![]).await; 304 305 snapshot_restore(group).await 306 } 307 308 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] snapshot_can_be_serialized_to_json_with_internals()309 async fn snapshot_can_be_serialized_to_json_with_internals() { 310 let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; 311 312 snapshot_restore(group).await 313 } 314 315 #[cfg(feature = "serde")] 316 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] serde()317 async fn serde() { 318 let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await; 319 let json = serde_json::to_string_pretty(&snapshot).unwrap(); 320 let recovered = serde_json::from_str(&json).unwrap(); 321 assert_eq!(snapshot, recovered); 322 } 323 } 324